Update message passing

yuruo 2025-03-04 14:32:24 +08:00
parent 9338b8411d
commit 9470235357
7 changed files with 9 additions and 539 deletions

View File

@@ -1,160 +0,0 @@
"""
Agentic sampling loop that calls the Anthropic API and a local implementation of Anthropic-defined computer-use tools.
"""
import asyncio
import platform
from collections.abc import Callable
from datetime import datetime
from enum import StrEnum
from typing import Any, cast
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types import (
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
BetaToolResultBlockParam,
)
from anthropic.types import TextBlock
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
from gradio_ui.tools import ComputerTool, ToolCollection, ToolResult
from PIL import Image
from io import BytesIO
import gradio as gr
from typing import Dict
BETA_FLAG = "computer-use-2024-10-22"
class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilizing a Windows system with internet access.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
</SYSTEM_CAPABILITY>
"""
class AnthropicActor:
def __init__(
self,
model: str,
api_key: str,
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
max_tokens: int = 4096,
only_n_most_recent_images: int | None = None,
        print_usage: bool = True,
        provider: APIProvider = APIProvider.ANTHROPIC,
):
self.model = model
self.api_key = api_key
self.api_response_callback = api_response_callback
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
self.tool_collection = ToolCollection(ComputerTool())
self.system = SYSTEM_PROMPT
self.total_token_usage = 0
self.total_cost = 0
self.print_usage = print_usage
# Instantiate the appropriate API client based on the provider
if provider == APIProvider.ANTHROPIC:
self.client = Anthropic(api_key=api_key)
elif provider == APIProvider.VERTEX:
self.client = AnthropicVertex()
elif provider == APIProvider.BEDROCK:
self.client = AnthropicBedrock()
def __call__(
self,
*,
messages: list[BetaMessageParam]
):
"""
Generate a response given history messages.
"""
if self.only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(messages, self.only_n_most_recent_images)
# Call the API synchronously
raw_response = self.client.beta.messages.with_raw_response.create(
max_tokens=self.max_tokens,
messages=messages,
model=self.model,
system=self.system,
tools=self.tool_collection.to_params(),
betas=["computer-use-2024-10-22"],
)
self.api_response_callback(cast(APIResponse[BetaMessage], raw_response))
response = raw_response.parse()
print(f"AnthropicActor response: {response}")
self.total_token_usage += response.usage.input_tokens + response.usage.output_tokens
self.total_cost += (response.usage.input_tokens * 3 / 1000000 + response.usage.output_tokens * 15 / 1000000)
if self.print_usage:
print(f"Claude total token usage so far: {self.total_token_usage}, total cost so far: $USD{self.total_cost}")
return response
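For orientation, a minimal usage sketch of the actor above (not part of the original file); the callback, API key, and message text are placeholders, and `provider` falls back to the Anthropic client.

```python
# Illustrative sketch only: wire up an AnthropicActor and request one completion.
def log_raw_response(raw_response):
    # Placeholder callback; a real UI would inspect status or headers here.
    print("received raw API response")

actor = AnthropicActor(
    model="claude-3-5-sonnet-20241022",
    api_key="sk-...",  # placeholder key
    api_response_callback=log_raw_response,
    only_n_most_recent_images=2,
    provider=APIProvider.ANTHROPIC,
)
reply = actor(messages=[{"role": "user", "content": "Open the Start menu."}])
print(reply.content)
```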
def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int = 10,
):
"""
With the assumption that images are screenshots that are of diminishing value as
the conversation progresses, remove all but the final `images_to_keep` tool_result
images in place, with a chunk of min_removal_threshold to reduce the amount we
break the implicit prompt cache.
"""
if images_to_keep is None:
return messages
tool_result_blocks = cast(
list[ToolResultBlockParam],
[
item
for message in messages
for item in (
message["content"] if isinstance(message["content"], list) else []
)
if isinstance(item, dict) and item.get("type") == "tool_result"
],
)
total_images = sum(
1
for tool_result in tool_result_blocks
for content in tool_result.get("content", [])
if isinstance(content, dict) and content.get("type") == "image"
)
images_to_remove = total_images - images_to_keep
# for better cache behavior, we want to remove in chunks
images_to_remove -= images_to_remove % min_removal_threshold
for tool_result in tool_result_blocks:
if isinstance(tool_result.get("content"), list):
new_content = []
for content in tool_result.get("content", []):
if isinstance(content, dict) and content.get("type") == "image":
if images_to_remove > 0:
images_to_remove -= 1
continue
new_content.append(content)
tool_result["content"] = new_content

View File

@@ -45,7 +45,7 @@ class VLMAgent:
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
self.output_callback = output_callback
self.model = model
self.print_usage = print_usage
self.total_token_usage = 0
self.total_cost = 0

View File

@@ -1,329 +0,0 @@
import base64
import json
import re
import time
import uuid
from collections.abc import Callable
from io import BytesIO
from typing import cast

from anthropic import APIResponse
from anthropic.types.beta import (
    BetaMessage,
    BetaMessageParam,
    BetaTextBlock,
    BetaToolUseBlock,
    BetaUsage,
)
from PIL import Image, ImageDraw

from gradio_ui.agent.llm_utils.oaiclient import run_oai_interleaved
from gradio_ui.agent.llm_utils.groqclient import run_groq_interleaved
from gradio_ui.agent.llm_utils.utils import is_image_path
OUTPUT_DIR = "./tmp/outputs"
def extract_data(input_string, data_type):
    # Regex to capture content after a ```{data_type} fence, up to the closing backticks
    # or the end of the string if the fence is never closed
    pattern = f"```{data_type}" + r"(.*?)(```|$)"
# Extract content
# re.DOTALL allows '.' to match newlines as well
matches = re.findall(pattern, input_string, re.DOTALL)
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
return matches[0][0].strip() if matches else input_string
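A quick illustration of `extract_data` on a made-up model reply; the sample string is hypothetical, not taken from the repository.

```python
# Illustrative only: pull the JSON payload out of a fenced model reply.
fence = "`" * 3  # avoid writing a literal fence inside this example
sample = f'Here is the plan:\n{fence}json\n{{"Next Action": "None"}}\n{fence}'
print(extract_data(sample, "json"))  # -> {"Next Action": "None"}
```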
class VLMAgent:
def __init__(
self,
model: str,
        base_url: str | None,
api_key: str,
output_callback: Callable,
api_response_callback: Callable,
max_tokens: int = 4096,
only_n_most_recent_images: int | None = None,
print_usage: bool = True,
):
self.api_key = api_key
self.model = model
self.base_url = base_url # Currently could be "", we should consider having None
self.api_response_callback = api_response_callback
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
self.output_callback = output_callback
self.print_usage = print_usage
self.total_token_usage = 0
self.total_cost = 0
self.step_count = 0
self.system = ''
    def __call__(self, messages: list, parsed_screen: dict):
self.step_count += 1
image_base64 = parsed_screen['original_screenshot_base64']
latency_omniparser = parsed_screen['latency']
self.output_callback(f'-- Step {self.step_count}: --', sender="bot")
screen_info = str(parsed_screen['screen_info'])
screenshot_uuid = parsed_screen['screenshot_uuid']
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
boxids_and_labels = parsed_screen["screen_info"]
system = self._get_system_prompt(boxids_and_labels)
        # Drop looping-action messages, raw image bytes, etc.
planner_messages = messages
_remove_som_images(planner_messages)
_maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
if isinstance(planner_messages[-1], dict):
if not isinstance(planner_messages[-1]["content"], list):
planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
        start = time.time()
        if "r1" in self.model:  # reasoning models are routed through the Groq client
            vlm_response, token_usage = run_groq_interleaved(
                messages=planner_messages,
                system=system,
                model_name=self.model,
                api_key=self.api_key,
                max_tokens=self.max_tokens,
            )
        else:
            vlm_response, token_usage = run_oai_interleaved(
                messages=planner_messages,
                system=system,
                model_name=self.model,
                api_key=self.api_key,
                max_tokens=min(2048, self.max_tokens),  # only Qwen caps completions at 2048
                provider_base_url=self.base_url,
                temperature=0,
            )
print(f"token usage: {token_usage}")
self.total_token_usage += token_usage
if 'gpt' in self.model:
self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/
elif 'o1' in self.model:
self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/
elif 'o3-mini' in self.model:
self.total_cost += (token_usage * 1.1 / 1000000) # https://openai.com/api/pricing/
elif 'qwen' in self.model:
self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
elif 'r1' in self.model:
self.total_cost += (token_usage * 0.99 / 1000000)
latency_vlm = time.time() - start
self.output_callback(f"LLM: {latency_vlm:.2f}s, OmniParser: {latency_omniparser:.2f}s", sender="bot")
print(f"{vlm_response}")
if self.print_usage:
print(f"Total token so far: {self.total_token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
vlm_response_json = extract_data(vlm_response, "json")
vlm_response_json = json.loads(vlm_response_json)
img_to_show_base64 = parsed_screen["som_image_base64"]
if "Box ID" in vlm_response_json:
try:
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"]
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
img_to_show_data = base64.b64decode(img_to_show_base64)
img_to_show = Image.open(BytesIO(img_to_show_data))
draw = ImageDraw.Draw(img_to_show)
x, y = vlm_response_json["box_centroid_coordinate"]
radius = 10
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
buffered = BytesIO()
img_to_show.save(buffered, format="PNG")
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            except Exception as e:
                print(f"Error parsing Box ID from: {vlm_response_json} ({e})")
self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
self.output_callback(
f'<details>'
            f'  <summary>Parsed screen elements from OmniParser</summary>'
f' <pre>{screen_info}</pre>'
f'</details>',
sender="bot"
)
vlm_plan_str = ""
for key, value in vlm_response_json.items():
if key == "Reasoning":
vlm_plan_str += f'{value}'
else:
vlm_plan_str += f'\n{key}: {value}'
        # Construct the response so that AnthropicExecutor can execute the tool
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
if 'box_centroid_coordinate' in vlm_response_json:
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
name='computer', type='tool_use')
response_content.append(move_cursor_block)
if vlm_response_json["Next Action"] == "None":
print("Task paused/completed.")
elif vlm_response_json["Next Action"] == "type":
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["Next Action"], 'text': vlm_response_json["value"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
else:
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["Next Action"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
return response_message, vlm_response_json
def _api_response_callback(self, response: APIResponse):
self.api_response_callback(response)
def _get_system_prompt(self, screen_info: str = ""):
main_section = f"""
You are using a Windows device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You can only interact with the desktop GUI (no terminal or application menu access).
You may be given a history of plans and actions; this is the response from the previous loop.
You should carefully consider your plan based on the task, the screenshot, and the history of actions.
Here is the list of all detected bounding boxes by IDs on the screen and their description:{screen_info}
Your available "Next Action" options are:
- type: types a string of text.
- left_click: moves the mouse to the box ID and left-clicks.
- right_click: moves the mouse to the box ID and right-clicks.
- double_click: moves the mouse to the box ID and double-clicks.
- hover: moves the mouse to the box ID.
- scroll_up: scrolls the screen up to view previous content.
- scroll_down: scrolls the screen down when the desired element is not visible or you need to see more content.
- wait: waits for 1 second for the device to load or respond.
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
Output format:
```json
{{
"Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
"Box ID": n,
"value": "xxx" # only provide value field if the action is type, else don't include value key
}}
```
One Example:
```json
{{
"Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
"Next Action": "left_click",
"Box ID": m
}}
```
Another Example:
```json
{{
"Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
"Next Action": "type",
"Box ID": n,
"value": "Apple watch"
}}
```
Another Example:
```json
{{
"Reasoning": "The current screen does not show 'submit' button, I need to scroll down to see if the button is available.",
"Next Action": "scroll_down",
}}
```
IMPORTANT NOTES:
1. You should only give a single action at a time.
"""
thinking_model = "r1" in self.model
if not thinking_model:
main_section += """
2. You should give an analysis of the current screen, reflect on what has been done by looking at the history, and then describe your step-by-step thoughts on how to achieve the task.
"""
else:
main_section += """
2. In <think> XML tags, give an analysis of the current screen, reflect on what has been done by looking at the history, and then describe your step-by-step thoughts on how to achieve the task. In <output> XML tags, put the next-action prediction JSON.
"""
main_section += """
3. Attach the next action prediction in the "Next Action" field.
4. You should not include other actions, such as keyboard shortcuts.
5. When the task is completed, don't take additional actions. You should say "Next Action": "None" in the json field.
6. If the task involves buying multiple products or navigating through multiple pages, break it into subgoals and complete each subgoal one by one in the order given by the instructions.
7. Avoid choosing the same action/element multiple times in a row; if that happens, reflect on what may have gone wrong and predict a different action.
8. If you are presented with a login page or a captcha page, or you think the next action needs the user's permission, you should say "Next Action": "None" in the json field.
"""
return main_section
def _remove_som_images(messages):
for msg in messages:
msg_content = msg["content"]
if isinstance(msg_content, list):
msg["content"] = [
cnt for cnt in msg_content
if not (isinstance(cnt, str) and 'som' in cnt and is_image_path(cnt))
]
def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int = 10,
):
"""
With the assumption that images are screenshots that are of diminishing value as
the conversation progresses, remove all but the final `images_to_keep` tool_result
images in place
"""
if images_to_keep is None:
return messages
total_images = 0
for msg in messages:
for cnt in msg.get("content", []):
if isinstance(cnt, str) and is_image_path(cnt):
total_images += 1
elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
for content in cnt.get("content", []):
if isinstance(content, dict) and content.get("type") == "image":
total_images += 1
images_to_remove = total_images - images_to_keep
for msg in messages:
msg_content = msg["content"]
if isinstance(msg_content, list):
new_content = []
for cnt in msg_content:
# Remove images from SOM or screenshot as needed
if isinstance(cnt, str) and is_image_path(cnt):
if images_to_remove > 0:
images_to_remove -= 1
continue
                # The VLM shouldn't use the Anthropic screenshot tool, so these shouldn't appear; but if they do, remove them as needed
elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
new_tool_result_content = []
for tool_result_entry in cnt.get("content", []):
if isinstance(tool_result_entry, dict) and tool_result_entry.get("type") == "image":
if images_to_remove > 0:
images_to_remove -= 1
continue
new_tool_result_content.append(tool_result_entry)
cnt["content"] = new_tool_result_content
# Append fixed content to current message's content list
new_content.append(cnt)
msg["content"] = new_content

View File

@@ -15,7 +15,6 @@ from anthropic.types import TextBlock
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
from anthropic.types.tool_use_block import ToolUseBlock
from gradio_ui.loop import (
APIProvider,
sampling_loop_sync,
)
from gradio_ui.tools import ToolResult
@@ -63,32 +62,12 @@ def setup_state(state):
if 'stop' not in state:
state['stop'] = False
if 'base_url' not in state:
state['base_url'] = ""
state['base_url'] = "https://api.openai-next.com/v1"
async def main(state):
"""Render loop for Gradio"""
setup_state(state)
return "Setup completed"
# Delete the entire validate_auth function
def validate_auth(provider: APIProvider, api_key: str | None):
if provider == APIProvider.ANTHROPIC:
if not api_key:
return "Enter your Anthropic API key to continue."
if provider == APIProvider.BEDROCK:
import boto3
if not boto3.Session().get_credentials():
return "You must have AWS credentials set up to use the Bedrock API."
if provider == APIProvider.VERTEX:
import google.auth
from google.auth.exceptions import DefaultCredentialsError
if not os.environ.get("CLOUD_ML_REGION"):
return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
try:
google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
except DefaultCredentialsError:
return "Your google cloud credentials are not set up correctly."
def load_from_storage(filename: str) -> str | None:
"""Load data from a file in the storage directory."""
@@ -188,15 +167,13 @@ def process_input(user_input, state):
state["messages"].append(
{
"role": Sender.USER,
"content": user_input,
"content": [TextBlock(type="text", text=user_input)],
}
)
# Append the user's message to chatbot_messages with None for the assistant's reply
    state['chatbot_messages'].append({"role": "user", "content": user_input})  # ensure the format is correct
    state['chatbot_messages'].append((user_input, None))  # ensure the format is correct
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
print(state)
# Run sampling_loop_sync with the chatbot_output_callback
for loop_msg in sampling_loop_sync(
model=state["model"],

View File

@@ -20,22 +20,6 @@ from gradio_ui.agent.anthropic_agent import AnthropicActor
from gradio_ui.agent.vlm_agent import VLMAgent
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
BETA_FLAG = "computer-use-2024-10-22"
class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"
OPENAI = "openai"
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
APIProvider.OPENAI: "gpt-4o",
}
def sampling_loop_sync(
*,
model: str,
@@ -75,9 +59,7 @@ def sampling_loop_sync(
while True:
parsed_screen = omniparser_client()
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content:
return messages

View File

@@ -8,9 +8,9 @@ import time
import torch
def run():
try:
        print(torch.cuda.is_available())  # should return True
        print(torch.cuda.device_count())  # should return at least 1
        print(torch.cuda.get_device_name(0))  # should show your GPU name
        print("cuda is_available: ", torch.cuda.is_available())  # should return True
        print("cuda device_count", torch.cuda.device_count())  # should return at least 1
        print("cuda device_name", torch.cuda.get_device_name(0))  # should show your GPU name
except Exception:
print("显卡驱动不适配请根据readme安装合适版本的 torch")
@@ -26,7 +26,7 @@ def run():
try:
        # Download the model weight files
download_weights.download()
print("启动Omniserver服务中20s左右请耐心等待")
print("启动Omniserver服务中40s左右请耐心等待")
        # Launch the Gradio UI
        # Wait for server_process to print "Started server process"
while True:

View File

@@ -11,7 +11,7 @@ class Omniparser(object):
self.som_model = get_yolo_model(model_path=config['som_model_path'])
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
print('Omniparser initialized!!!')
print('Omniparser initialized!')
def parse(self, image_base64: str):
image_bytes = base64.b64decode(image_base64)