add custom agent

This commit is contained in:
warmshao
2025-01-02 21:07:20 +08:00
parent 6c07ec2603
commit 863e865446
5 changed files with 461 additions and 20 deletions

View File

@@ -5,11 +5,16 @@
# @FileName: custom_agent.py
import asyncio
import base64
import io
import json
import logging
import os
import pdb
import textwrap
import time
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Type, TypeVar
@@ -20,10 +25,12 @@ from langchain_core.messages import (
SystemMessage,
)
from openai import RateLimitError
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel, ValidationError
from browser_use.agent.message_manager.service import MessageManager
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
ActionResult,
AgentError,
@@ -32,21 +39,76 @@ from browser_use.agent.views import (
AgentOutput,
AgentStepInfo,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserState, BrowserStateHistory
from browser_use.controller.registry.views import ActionModel
from browser_use.controller.service import Controller
from browser_use.dom.history_tree_processor.service import (
DOMHistoryElement,
HistoryTreeProcessor,
)
from browser_use.telemetry.service import ProductTelemetry
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
AgentRunTelemetryEvent,
AgentStepErrorTelemetryEvent,
)
from browser_use.agent.service import Agent
from browser_use.utils import time_execution_async
from .custom_views import CustomAgentOutput
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
from .custom_massage_manager import CustomMassageManager
logger = logging.getLogger(__name__)
class CustomAgent(Agent):
def __init__(
self,
task: str,
llm: BaseChatModel,
add_infos: str = '',
browser: Browser | None = None,
browser_context: BrowserContext | None = None,
controller: Controller = Controller(),
use_vision: bool = True,
save_conversation_path: Optional[str] = None,
max_failures: int = 5,
retry_delay: int = 10,
system_prompt_class: Type[SystemPrompt] = SystemPrompt,
max_input_tokens: int = 128000,
validate_output: bool = False,
include_attributes: list[str] = [
'title',
'type',
'name',
'role',
'tabindex',
'aria-label',
'placeholder',
'value',
'alt',
'aria-expanded',
],
max_error_length: int = 400,
max_actions_per_step: int = 10,
):
super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path,
max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output,
include_attributes, max_error_length, max_actions_per_step)
self.add_infos = add_infos
self.message_manager = CustomMassageManager(
llm=self.llm,
task=self.task,
action_descriptions=self.controller.registry.get_prompt_description(),
system_prompt_class=self.system_prompt_class,
max_input_tokens=self.max_input_tokens,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
max_actions_per_step=self.max_actions_per_step,
)
def _setup_action_models(self) -> None:
"""Setup dynamic action models from controller's registry"""
# Get the dynamic action model from controller's registry
@@ -56,23 +118,42 @@ class CustomAgent(Agent):
def _log_response(self, response: CustomAgentOutput) -> None:
"""Log the model's response"""
if 'Success' in response.current_state.evaluation_previous_goal:
emoji = '👍'
elif 'Failed' in response.current_state.evaluation_previous_goal:
emoji = ''
if 'Success' in response.current_state.prev_action_evaluation:
emoji = ''
elif 'Failed' in response.current_state.prev_action_evaluation:
emoji = ''
else:
emoji = '🤷'
logger.info(f'{emoji} Eval: {response.current_state.evaluation_previous_goal}')
logger.info(f'🧠 Memory: {response.current_state.memory}')
logger.info(f'🎯 Next goal: {response.current_state.next_goal}')
logger.info(f'{emoji} Eval: {response.current_state.prev_action_evaluation}')
logger.info(f'🧠 Memory: {response.current_state.import_contents}')
logger.info(f'⏳ Task Progress: {response.current_state.completed_contents}')
logger.info(f'🤔 Thought: {response.current_state.thought}')
logger.info(f'🎯 Summary: {response.current_state.summary}')
for i, action in enumerate(response.action):
logger.info(
f'🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}'
)
def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None):
"""
update step info
"""
if step_info is None:
return
step_info.step_number += 1
import_contents = model_output.current_state.import_contents
if import_contents and 'None' not in import_contents and import_contents not in step_info.memory:
step_info.memory += import_contents + '\n'
completed_contents = model_output.current_state.completed_contents
if completed_contents and 'None' not in completed_contents:
step_info.task_progress = completed_contents
@time_execution_async('--step')
async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
"""Execute one step of the task"""
logger.info(f'\n📍 Step {self.n_steps}')
state = None
@@ -84,6 +165,7 @@ class CustomAgent(Agent):
self.message_manager.add_state_message(state, self._last_result, step_info)
input_messages = self.message_manager.get_messages()
model_output = await self.get_next_action(input_messages)
self.update_step_info(model_output, step_info)
self._save_conversation(input_messages, model_output)
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
self.message_manager.add_model_output(model_output)
@@ -115,3 +197,87 @@ class CustomAgent(Agent):
)
if state:
self._make_history_item(model_output, state, result)
def _make_history_item(
self,
model_output: CustomAgentOutput | None,
state: BrowserState,
result: list[ActionResult],
) -> None:
"""Create and store history item"""
interacted_element = None
len_result = len(result)
if model_output:
interacted_elements = AgentHistory.get_interacted_element(
model_output, state.selector_map
)
else:
interacted_elements = [None]
state_history = BrowserStateHistory(
url=state.url,
title=state.title,
tabs=state.tabs,
interacted_element=interacted_elements,
screenshot=state.screenshot,
)
history_item = AgentHistory(model_output=model_output, result=result, state=state_history)
self.history.history.append(history_item)
async def run(self, max_steps: int = 100) -> AgentHistoryList:
"""Execute the task with maximum number of steps"""
try:
logger.info(f'🚀 Starting task: {self.task}')
self.telemetry.capture(
AgentRunTelemetryEvent(
agent_id=self.agent_id,
task=self.task,
)
)
step_info = CustomAgentStepInfo(task=self.task,
add_infos=self.add_infos,
step_number=1,
max_steps=max_steps,
memory='',
task_progress=''
)
for step in range(max_steps):
if self._too_many_failures():
break
await self.step(step_info)
if self.history.is_done():
if (
self.validate_output and step < max_steps - 1
): # if last step, we dont need to validate
if not await self._validate_output():
continue
logger.info('✅ Task completed successfully')
break
else:
logger.info('❌ Failed to complete task in maximum steps')
return self.history
finally:
self.telemetry.capture(
AgentEndTelemetryEvent(
agent_id=self.agent_id,
task=self.task,
success=self.history.is_done(),
steps=len(self.history.history),
)
)
if not self.injected_browser_context:
await self.browser_context.close()
if not self.injected_browser and self.browser:
await self.browser.close()

View File

@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# @Time : 2025/1/2
# @Author : wenshao
# @ProjectName: browser-use-webui
# @FileName: custom_massage_manager.py
from __future__ import annotations
import logging
from datetime import datetime
from typing import List, Optional, Type
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
)
from langchain_openai import ChatOpenAI
from browser_use.agent.message_manager.views import MessageHistory, MessageMetadata
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo
from browser_use.browser.views import BrowserState
from browser_use.agent.message_manager.service import MessageManager
from .custom_prompts import CustomAgentMessagePrompt
logger = logging.getLogger(__name__)
class CustomMassageManager(MessageManager):
def __init__(
self,
llm: BaseChatModel,
task: str,
action_descriptions: str,
system_prompt_class: Type[SystemPrompt],
max_input_tokens: int = 128000,
estimated_tokens_per_character: int = 3,
image_tokens: int = 800,
include_attributes: list[str] = [],
max_error_length: int = 400,
max_actions_per_step: int = 10,
):
super().__init__(llm, task, action_descriptions, system_prompt_class, max_input_tokens,
estimated_tokens_per_character, image_tokens, include_attributes, max_error_length,
max_actions_per_step)
# Move Task info to state_message
self.history = MessageHistory()
self._add_message_with_tokens(self.system_prompt)
def add_state_message(
self,
state: BrowserState,
result: Optional[List[ActionResult]] = None,
step_info: Optional[AgentStepInfo] = None,
) -> None:
"""Add browser state as human message"""
# if keep in memory, add to directly to history and add state without result
if result:
for r in result:
if r.include_in_memory:
if r.extracted_content:
msg = HumanMessage(content=str(r.extracted_content))
self._add_message_with_tokens(msg)
if r.error:
msg = HumanMessage(content=str(r.error)[-self.max_error_length:])
self._add_message_with_tokens(msg)
result = None # if result in history, we dont want to add it again
# otherwise add state message and result to next message (which will not stay in memory)
state_message = CustomAgentMessagePrompt(
state,
result,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
step_info=step_info,
).get_user_message()
self._add_message_with_tokens(state_message)

View File

@@ -11,7 +11,195 @@ from langchain_core.messages import HumanMessage, SystemMessage
from browser_use.agent.views import ActionResult, AgentStepInfo
from browser_use.browser.views import BrowserState
from browser_use.agent.prompts import SystemPrompt
from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt
from .custom_views import CustomAgentStepInfo
class CustomSystemPrompt(SystemPrompt):
pass
def important_rules(self) -> str:
"""
Returns the important rules for the agent.
"""
text = """
1. RESPONSE FORMAT: You must ALWAYS respond with valid JSON in this exact format:
{
"current_state": {
"prev_action_evaluation": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Ignore the action result. The website is the ground truth. Also mention if something unexpected happened like new suggestions in an input field. Shortly state why/why not. Note that the result you output must be consistent with the reasoning you output afterwards. If you consider it to be 'Failed,' you should reflect on this during your thought.",
"import_contents": "Please think about whether there is any content closely related to user\'s instruction or task on the current page? If there is, please output the content. If not, please output \"None\".",
"completed_contents": "Update the task progress. Don\'t output the purpose of any operation. Completed contents is a general summary of the current contents that have been completed. Just summarize the contents that have been actually completed based on the current and the history operations. Please list each completed item individually, such as: 1. Input username. 2. Input Password. 3. Click confirm button",
"thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If the output of prev_action_evaluation is 'Failed', please reflect and output your reflection here. If you think you have entered the wrong page, consider to go back to the previous page in next action.",
"summary": "Please generate a brief natural language description for the operation in next Actions based on your Thought."
},
"action": [
{
"action_name": {
// action-specific parameters
}
},
// ... more actions in sequence
]
}
2. ACTIONS: You can specify multiple actions to be executed in sequence.
Common action sequences:
- Form filling: [
{"input_text": {"index": 1, "text": "username"}},
{"input_text": {"index": 2, "text": "password"}},
{"click_element": {"index": 3}}
]
- Navigation and extraction: [
{"open_new_tab": {}},
{"go_to_url": {"url": "https://example.com"}},
{"extract_page_content": {}}
]
3. ELEMENT INTERACTION:
- Only use indexes that exist in the provided element list
- Each element has a unique index number (e.g., "33[:]<button>")
- Elements marked with "_[:]" are non-interactive (for context only)
4. NAVIGATION & ERROR HANDLING:
- If no suitable elements exist, use other functions to complete the task
- If stuck, try alternative approaches
- Handle popups/cookies by accepting or closing them
- Use scroll to find elements you are looking for
5. TASK COMPLETION:
- If you think all the requirements of user\'s instruction have been completed and no further operation is required, output the done action to terminate the operation process.
- Don't hallucinate actions.
- If the task requires specific information - make sure to include everything in the done function. This is what the user will see.
- If you are running out of steps (current step), think about speeding it up, and ALWAYS use the done action as the last action.
6. VISUAL CONTEXT:
- When an image is provided, use it to understand the page layout
- Bounding boxes with labels correspond to element indexes
- Each bounding box and its label have the same color
- Most often the label is inside the bounding box, on the top right
- Visual context helps verify element locations and relationships
- sometimes labels overlap, so use the context to verify the correct element
7. Form filling:
- If you fill a input field and your action sequence is interrupted, most often a list with suggestions poped up under the field and you need to first select the right element from the suggestion list.
8. ACTION SEQUENCING:
- Actions are executed in the order they appear in the list
- Each action should logically follow from the previous one
- If the page changes after an action, the sequence is interrupted and you get the new state.
- If content only disappears the sequence continues.
- Only provide the action sequence until you think the page will change.
- Try to be efficient, e.g. fill forms at once, or chain actions where nothing changes on the page like saving, extracting, checkboxes...
- only use multiple actions if it makes sense.
"""
text += f' - use maximum {self.max_actions_per_step} actions per sequence'
return text
def input_format(self) -> str:
return """
INPUT STRUCTURE:
1. Task: The user\'s instructions you need to complete.
2. Hints(Optional): Some hints to help you complete the user\'s instructions.
3. Memory: Important contents are recorded during historical operations for use in subsequent operations.
4. Task Progress: Up to the current page, the content you have completed can be understood as the progress of the task.
5. Current URL: The webpage you're currently on
6. Available Tabs: List of open browser tabs
7. Interactive Elements: List in the format:
index[:]<element_type>element_text</element_type>
- index: Numeric identifier for interaction
- element_type: HTML element type (button, input, etc.)
- element_text: Visible text or element description
Example:
33[:]<button>Submit Form</button>
_[:] Non-interactive text
Notes:
- Only elements with numeric indexes are interactive
- _[:] elements provide context but cannot be interacted with
"""
def get_system_message(self) -> SystemMessage:
"""
Get the system prompt for the agent.
Returns:
str: Formatted system prompt
"""
time_str = self.current_date.strftime('%Y-%m-%d %H:%M')
AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
1. Analyze the provided webpage elements and structure
2. Plan a sequence of actions to accomplish the given task
3. Respond with valid JSON containing your action sequence and state assessment
Current date and time: {time_str}
{self.input_format()}
{self.important_rules()}
Functions:
{self.default_action_description}
Remember: Your responses must be valid JSON matching the specified format. Each action in the sequence must be valid."""
return SystemMessage(content=AGENT_PROMPT)
class CustomAgentMessagePrompt:
def __init__(
self,
state: BrowserState,
result: Optional[List[ActionResult]] = None,
include_attributes: list[str] = [],
max_error_length: int = 400,
step_info: Optional[CustomAgentStepInfo] = None,
):
self.state = state
self.result = result
self.max_error_length = max_error_length
self.include_attributes = include_attributes
self.step_info = step_info
def get_user_message(self) -> HumanMessage:
state_description = f"""
1. Task: {self.step_info.task}
2. Hints(Optional):
{self.step_info.add_infos}
3. Memory:
{self.step_info.memory}
4. Task Progress:
{self.step_info.task_progress}
5. Current url: {self.state.url}
6. Available tabs:
{self.state.tabs}
7. Interactive elements:
{self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)}
"""
if self.result:
for i, result in enumerate(self.result):
if result.extracted_content:
state_description += (
f'\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}'
)
if result.error:
# only use last 300 characters of error
error = result.error[-self.max_error_length:]
state_description += f'\nError of action {i + 1}/{len(self.result)}: ...{error}'
if self.state.screenshot:
# Format message for vision model
return HumanMessage(
content=[
{'type': 'text', 'text': state_description},
{
'type': 'image_url',
'image_url': {'url': f'data:image/png;base64,{self.state.screenshot}'},
},
]
)
return HumanMessage(content=state_description)

View File

@@ -8,28 +8,30 @@ from dataclasses import dataclass
from typing import Type
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
from browser_use.controller.registry.views import ActionModel
from browser_use.agent.views import AgentOutput
@dataclass
class CustomAgentStepInfo:
step_number: int
max_steps: int
task: str
add_infos: str
memory: str
task_progress: str
class CustomAgentBrain(BaseModel):
"""Current state of the agent"""
prev_action_evaluation: str
memory: str
progress: str
import_contents: str
completed_contents: str
thought: str
summary: str
action: str
class CustomAgentOutput(BaseModel):
class CustomAgentOutput(AgentOutput):
"""Output model for agent
@dev note: this model is extended with custom actions in AgentService. You can also use some fields that are not in this model as provided by the linter, as long as they are registered in the DynamicActions model.

View File

@@ -74,11 +74,13 @@ async def test_browser_use_org():
async def test_browser_use_custom():
from browser_use.browser.context import BrowserContextWindowSize
from src.browser.custom_browser import CustomBrowser, BrowserConfig
from src.browser.custom_context import BrowserContext, BrowserContextConfig
from src.controller.custom_controller import CustomController
from browser_use.browser.context import BrowserContextWindowSize
from src.agent.custom_agent import CustomAgent
from src.agent.custom_prompts import CustomSystemPrompt
window_w, window_h = 1920, 1080
@@ -105,7 +107,6 @@ async def test_browser_use_custom():
)
)
controller = CustomController()
async with await browser.new_context(
config=BrowserContextConfig(
trace_path='./tmp/traces',
@@ -119,6 +120,7 @@ async def test_browser_use_custom():
llm=llm,
browser_context=browser_context,
controller=controller,
system_prompt_class=CustomSystemPrompt
)
history: AgentHistoryList = await agent.run(max_steps=10)