mirror of
https://github.com/browser-use/web-ui.git
synced 2026-03-22 11:17:17 +08:00
Update custom_agent.py
This commit is contained in:
@@ -4,71 +4,45 @@
|
||||
# @ProjectName: browser-use-webui
|
||||
# @FileName: custom_agent.py
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pdb
|
||||
import textwrap
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
import traceback
|
||||
from typing import Optional, Type
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from openai import RateLimitError
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from browser_use.agent.message_manager.service import MessageManager
|
||||
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
|
||||
from browser_use.agent.prompts import SystemPrompt
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.agent.views import (
|
||||
ActionResult,
|
||||
AgentError,
|
||||
AgentHistory,
|
||||
AgentHistoryList,
|
||||
AgentOutput,
|
||||
AgentStepInfo,
|
||||
)
|
||||
from browser_use.browser.browser import Browser
|
||||
from browser_use.browser.context import BrowserContext
|
||||
from browser_use.browser.views import BrowserState, BrowserStateHistory
|
||||
from browser_use.controller.registry.views import ActionModel
|
||||
from browser_use.controller.service import Controller
|
||||
from browser_use.dom.history_tree_processor.service import (
|
||||
DOMHistoryElement,
|
||||
HistoryTreeProcessor,
|
||||
)
|
||||
from browser_use.telemetry.service import ProductTelemetry
|
||||
from browser_use.telemetry.views import (
|
||||
AgentEndTelemetryEvent,
|
||||
AgentRunTelemetryEvent,
|
||||
AgentStepErrorTelemetryEvent,
|
||||
)
|
||||
from browser_use.utils import time_execution_async
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
)
|
||||
|
||||
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
|
||||
from .custom_massage_manager import CustomMassageManager
|
||||
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomAgent(Agent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
llm: BaseChatModel,
|
||||
add_infos: str = '',
|
||||
add_infos: str = "",
|
||||
browser: Browser | None = None,
|
||||
browser_context: BrowserContext | None = None,
|
||||
controller: Controller = Controller(),
|
||||
@@ -80,23 +54,39 @@ class CustomAgent(Agent):
|
||||
max_input_tokens: int = 128000,
|
||||
validate_output: bool = False,
|
||||
include_attributes: list[str] = [
|
||||
'title',
|
||||
'type',
|
||||
'name',
|
||||
'role',
|
||||
'tabindex',
|
||||
'aria-label',
|
||||
'placeholder',
|
||||
'value',
|
||||
'alt',
|
||||
'aria-expanded',
|
||||
"title",
|
||||
"type",
|
||||
"name",
|
||||
"role",
|
||||
"tabindex",
|
||||
"aria-label",
|
||||
"placeholder",
|
||||
"value",
|
||||
"alt",
|
||||
"aria-expanded",
|
||||
],
|
||||
max_error_length: int = 400,
|
||||
max_actions_per_step: int = 10,
|
||||
tool_call_in_content: bool = True,
|
||||
):
|
||||
super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path,
|
||||
max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output,
|
||||
include_attributes, max_error_length, max_actions_per_step)
|
||||
super().__init__(
|
||||
task=task,
|
||||
llm=llm,
|
||||
browser=browser,
|
||||
browser_context=browser_context,
|
||||
controller=controller,
|
||||
use_vision=use_vision,
|
||||
save_conversation_path=save_conversation_path,
|
||||
max_failures=max_failures,
|
||||
retry_delay=retry_delay,
|
||||
system_prompt_class=system_prompt_class,
|
||||
max_input_tokens=max_input_tokens,
|
||||
validate_output=validate_output,
|
||||
include_attributes=include_attributes,
|
||||
max_error_length=max_error_length,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content,
|
||||
)
|
||||
self.add_infos = add_infos
|
||||
self.message_manager = CustomMassageManager(
|
||||
llm=self.llm,
|
||||
@@ -107,6 +97,7 @@ class CustomAgent(Agent):
|
||||
include_attributes=self.include_attributes,
|
||||
max_error_length=self.max_error_length,
|
||||
max_actions_per_step=self.max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content,
|
||||
)
|
||||
|
||||
def _setup_action_models(self) -> None:
|
||||
@@ -118,24 +109,26 @@ class CustomAgent(Agent):
|
||||
|
||||
def _log_response(self, response: CustomAgentOutput) -> None:
|
||||
"""Log the model's response"""
|
||||
if 'Success' in response.current_state.prev_action_evaluation:
|
||||
emoji = '✅'
|
||||
elif 'Failed' in response.current_state.prev_action_evaluation:
|
||||
emoji = '❌'
|
||||
if "Success" in response.current_state.prev_action_evaluation:
|
||||
emoji = "✅"
|
||||
elif "Failed" in response.current_state.prev_action_evaluation:
|
||||
emoji = "❌"
|
||||
else:
|
||||
emoji = '🤷'
|
||||
emoji = "🤷"
|
||||
|
||||
logger.info(f'{emoji} Eval: {response.current_state.prev_action_evaluation}')
|
||||
logger.info(f'🧠 New Memory: {response.current_state.important_contents}')
|
||||
logger.info(f'⏳ Task Progress: {response.current_state.completed_contents}')
|
||||
logger.info(f'🤔 Thought: {response.current_state.thought}')
|
||||
logger.info(f'🎯 Summary: {response.current_state.summary}')
|
||||
logger.info(f"{emoji} Eval: {response.current_state.prev_action_evaluation}")
|
||||
logger.info(f"🧠 New Memory: {response.current_state.important_contents}")
|
||||
logger.info(f"⏳ Task Progress: {response.current_state.completed_contents}")
|
||||
logger.info(f"🤔 Thought: {response.current_state.thought}")
|
||||
logger.info(f"🎯 Summary: {response.current_state.summary}")
|
||||
for i, action in enumerate(response.action):
|
||||
logger.info(
|
||||
f'🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}'
|
||||
f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}"
|
||||
)
|
||||
|
||||
def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None):
|
||||
def update_step_info(
|
||||
self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None
|
||||
):
|
||||
"""
|
||||
update step info
|
||||
"""
|
||||
@@ -144,31 +137,54 @@ class CustomAgent(Agent):
|
||||
|
||||
step_info.step_number += 1
|
||||
important_contents = model_output.current_state.important_contents
|
||||
if important_contents and 'None' not in important_contents and important_contents not in step_info.memory:
|
||||
step_info.memory += important_contents + '\n'
|
||||
if (
|
||||
important_contents
|
||||
and "None" not in important_contents
|
||||
and important_contents not in step_info.memory
|
||||
):
|
||||
step_info.memory += important_contents + "\n"
|
||||
|
||||
completed_contents = model_output.current_state.completed_contents
|
||||
if completed_contents and 'None' not in completed_contents:
|
||||
if completed_contents and "None" not in completed_contents:
|
||||
step_info.task_progress = completed_contents
|
||||
|
||||
@time_execution_async('--get_next_action')
|
||||
@time_execution_async("--get_next_action")
|
||||
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
|
||||
"""Get next action from LLM based on current state"""
|
||||
try:
|
||||
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
|
||||
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
|
||||
|
||||
ret = self.llm.invoke(input_messages)
|
||||
parsed_json = json.loads(ret.content.replace('```json', '').replace("```", ""))
|
||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||
# cut the number of actions to max_actions_per_step
|
||||
parsed.action = parsed.action[: self.max_actions_per_step]
|
||||
self._log_response(parsed)
|
||||
self.n_steps += 1
|
||||
parsed: AgentOutput = response['parsed']
|
||||
# cut the number of actions to max_actions_per_step
|
||||
parsed.action = parsed.action[: self.max_actions_per_step]
|
||||
self._log_response(parsed)
|
||||
self.n_steps += 1
|
||||
|
||||
return parsed
|
||||
return parsed
|
||||
except Exception as e:
|
||||
# If something goes wrong, try to invoke the LLM again without structured output,
|
||||
# and manually parse the response. Temporary solution for DeepSeek.
|
||||
ret = self.llm.invoke(input_messages)
|
||||
if isinstance(ret.content, list):
|
||||
parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
|
||||
else:
|
||||
parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
|
||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||
if parsed is None:
|
||||
raise ValueError(f'Could not parse response.')
|
||||
|
||||
@time_execution_async('--step')
|
||||
# cut the number of actions to max_actions_per_step
|
||||
parsed.action = parsed.action[: self.max_actions_per_step]
|
||||
self._log_response(parsed)
|
||||
self.n_steps += 1
|
||||
|
||||
return parsed
|
||||
|
||||
@time_execution_async("--step")
|
||||
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
|
||||
"""Execute one step of the task"""
|
||||
logger.info(f'\n📍 Step {self.n_steps}')
|
||||
logger.info(f"\n📍 Step {self.n_steps}")
|
||||
state = None
|
||||
model_output = None
|
||||
result: list[ActionResult] = []
|
||||
@@ -179,7 +195,7 @@ class CustomAgent(Agent):
|
||||
input_messages = self.message_manager.get_messages()
|
||||
model_output = await self.get_next_action(input_messages)
|
||||
self.update_step_info(model_output, step_info)
|
||||
logger.info(f'🧠 All Memory: {step_info.memory}')
|
||||
logger.info(f"🧠 All Memory: {step_info.memory}")
|
||||
self._save_conversation(input_messages, model_output)
|
||||
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
|
||||
self.message_manager.add_model_output(model_output)
|
||||
@@ -190,7 +206,7 @@ class CustomAgent(Agent):
|
||||
self._last_result = result
|
||||
|
||||
if len(result) > 0 and result[-1].is_done:
|
||||
logger.info(f'📄 Result: {result[-1].extracted_content}')
|
||||
logger.info(f"📄 Result: {result[-1].extracted_content}")
|
||||
|
||||
self.consecutive_failures = 0
|
||||
|
||||
@@ -215,7 +231,7 @@ class CustomAgent(Agent):
|
||||
async def run(self, max_steps: int = 100) -> AgentHistoryList:
|
||||
"""Execute the task with maximum number of steps"""
|
||||
try:
|
||||
logger.info(f'🚀 Starting task: {self.task}')
|
||||
logger.info(f"🚀 Starting task: {self.task}")
|
||||
|
||||
self.telemetry.capture(
|
||||
AgentRunTelemetryEvent(
|
||||
@@ -224,13 +240,14 @@ class CustomAgent(Agent):
|
||||
)
|
||||
)
|
||||
|
||||
step_info = CustomAgentStepInfo(task=self.task,
|
||||
add_infos=self.add_infos,
|
||||
step_number=1,
|
||||
max_steps=max_steps,
|
||||
memory='',
|
||||
task_progress=''
|
||||
)
|
||||
step_info = CustomAgentStepInfo(
|
||||
task=self.task,
|
||||
add_infos=self.add_infos,
|
||||
step_number=1,
|
||||
max_steps=max_steps,
|
||||
memory="",
|
||||
task_progress="",
|
||||
)
|
||||
|
||||
for step in range(max_steps):
|
||||
if self._too_many_failures():
|
||||
@@ -245,10 +262,10 @@ class CustomAgent(Agent):
|
||||
if not await self._validate_output():
|
||||
continue
|
||||
|
||||
logger.info('✅ Task completed successfully')
|
||||
logger.info("✅ Task completed successfully")
|
||||
break
|
||||
else:
|
||||
logger.info('❌ Failed to complete task in maximum steps')
|
||||
logger.info("❌ Failed to complete task in maximum steps")
|
||||
|
||||
return self.history
|
||||
|
||||
|
||||
Reference in New Issue
Block a user