From 3740b93746777f9c490fd5883dbc4b438debe4d7 Mon Sep 17 00:00:00 2001 From: Richardson Gunde <152559661+richard-devbot@users.noreply.github.com> Date: Wed, 8 Jan 2025 19:53:10 +0530 Subject: [PATCH] Update custom_agent.py --- src/agent/custom_agent.py | 189 +++++++++++++++++++++----------------- 1 file changed, 103 insertions(+), 86 deletions(-) diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index 027a450..3bf5496 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -4,71 +4,45 @@ # @ProjectName: browser-use-webui # @FileName: custom_agent.py -import asyncio -import base64 -import io import json import logging -import os import pdb -import textwrap -import time -import uuid -from io import BytesIO -from pathlib import Path -from typing import Any, Optional, Type, TypeVar +import traceback +from typing import Optional, Type -from dotenv import load_dotenv -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import ( - BaseMessage, - SystemMessage, -) -from openai import RateLimitError -from PIL import Image, ImageDraw, ImageFont -from pydantic import BaseModel, ValidationError - -from browser_use.agent.message_manager.service import MessageManager -from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt +from browser_use.agent.prompts import SystemPrompt from browser_use.agent.service import Agent from browser_use.agent.views import ( ActionResult, - AgentError, - AgentHistory, AgentHistoryList, AgentOutput, - AgentStepInfo, ) from browser_use.browser.browser import Browser from browser_use.browser.context import BrowserContext -from browser_use.browser.views import BrowserState, BrowserStateHistory -from browser_use.controller.registry.views import ActionModel from browser_use.controller.service import Controller -from browser_use.dom.history_tree_processor.service import ( - DOMHistoryElement, - HistoryTreeProcessor, -) -from browser_use.telemetry.service import ProductTelemetry from browser_use.telemetry.views import ( AgentEndTelemetryEvent, AgentRunTelemetryEvent, AgentStepErrorTelemetryEvent, ) from browser_use.utils import time_execution_async +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + BaseMessage, +) -from .custom_views import CustomAgentOutput, CustomAgentStepInfo from .custom_massage_manager import CustomMassageManager +from .custom_views import CustomAgentOutput, CustomAgentStepInfo logger = logging.getLogger(__name__) class CustomAgent(Agent): - def __init__( self, task: str, llm: BaseChatModel, - add_infos: str = '', + add_infos: str = "", browser: Browser | None = None, browser_context: BrowserContext | None = None, controller: Controller = Controller(), @@ -80,23 +54,39 @@ class CustomAgent(Agent): max_input_tokens: int = 128000, validate_output: bool = False, include_attributes: list[str] = [ - 'title', - 'type', - 'name', - 'role', - 'tabindex', - 'aria-label', - 'placeholder', - 'value', - 'alt', - 'aria-expanded', + "title", + "type", + "name", + "role", + "tabindex", + "aria-label", + "placeholder", + "value", + "alt", + "aria-expanded", ], max_error_length: int = 400, max_actions_per_step: int = 10, + tool_call_in_content: bool = True, ): - super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path, - max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output, - include_attributes, max_error_length, max_actions_per_step) + super().__init__( + task=task, + llm=llm, + browser=browser, + browser_context=browser_context, + controller=controller, + use_vision=use_vision, + save_conversation_path=save_conversation_path, + max_failures=max_failures, + retry_delay=retry_delay, + system_prompt_class=system_prompt_class, + max_input_tokens=max_input_tokens, + validate_output=validate_output, + include_attributes=include_attributes, + max_error_length=max_error_length, + max_actions_per_step=max_actions_per_step, + tool_call_in_content=tool_call_in_content, + ) self.add_infos = add_infos self.message_manager = CustomMassageManager( llm=self.llm, @@ -107,6 +97,7 @@ class CustomAgent(Agent): include_attributes=self.include_attributes, max_error_length=self.max_error_length, max_actions_per_step=self.max_actions_per_step, + tool_call_in_content=tool_call_in_content, ) def _setup_action_models(self) -> None: @@ -118,24 +109,26 @@ class CustomAgent(Agent): def _log_response(self, response: CustomAgentOutput) -> None: """Log the model's response""" - if 'Success' in response.current_state.prev_action_evaluation: - emoji = 'āœ…' - elif 'Failed' in response.current_state.prev_action_evaluation: - emoji = 'āŒ' + if "Success" in response.current_state.prev_action_evaluation: + emoji = "āœ…" + elif "Failed" in response.current_state.prev_action_evaluation: + emoji = "āŒ" else: - emoji = '🤷' + emoji = "🤷" - logger.info(f'{emoji} Eval: {response.current_state.prev_action_evaluation}') - logger.info(f'🧠 New Memory: {response.current_state.important_contents}') - logger.info(f'ā³ Task Progress: {response.current_state.completed_contents}') - logger.info(f'šŸ¤” Thought: {response.current_state.thought}') - logger.info(f'šŸŽÆ Summary: {response.current_state.summary}') + logger.info(f"{emoji} Eval: {response.current_state.prev_action_evaluation}") + logger.info(f"🧠 New Memory: {response.current_state.important_contents}") + logger.info(f"ā³ Task Progress: {response.current_state.completed_contents}") + logger.info(f"šŸ¤” Thought: {response.current_state.thought}") + logger.info(f"šŸŽÆ Summary: {response.current_state.summary}") for i, action in enumerate(response.action): logger.info( - f'šŸ› ļø Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}' + f"šŸ› ļø Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}" ) - def update_step_info(self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None): + def update_step_info( + self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None + ): """ update step info """ @@ -144,31 +137,54 @@ class CustomAgent(Agent): step_info.step_number += 1 important_contents = model_output.current_state.important_contents - if important_contents and 'None' not in important_contents and important_contents not in step_info.memory: - step_info.memory += important_contents + '\n' + if ( + important_contents + and "None" not in important_contents + and important_contents not in step_info.memory + ): + step_info.memory += important_contents + "\n" completed_contents = model_output.current_state.completed_contents - if completed_contents and 'None' not in completed_contents: + if completed_contents and "None" not in completed_contents: step_info.task_progress = completed_contents - @time_execution_async('--get_next_action') + @time_execution_async("--get_next_action") async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput: """Get next action from LLM based on current state""" + try: + structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True) + response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore - ret = self.llm.invoke(input_messages) - parsed_json = json.loads(ret.content.replace('```json', '').replace("```", "")) - parsed: AgentOutput = self.AgentOutput(**parsed_json) - # cut the number of actions to max_actions_per_step - parsed.action = parsed.action[: self.max_actions_per_step] - self._log_response(parsed) - self.n_steps += 1 + parsed: AgentOutput = response['parsed'] + # cut the number of actions to max_actions_per_step + parsed.action = parsed.action[: self.max_actions_per_step] + self._log_response(parsed) + self.n_steps += 1 - return parsed + return parsed + except Exception as e: + # If something goes wrong, try to invoke the LLM again without structured output, + # and Manually parse the response. Temporarily solution for DeepSeek + ret = self.llm.invoke(input_messages) + if isinstance(ret.content, list): + parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", "")) + else: + parsed_json = json.loads(ret.content.replace("```json", "").replace("```", "")) + parsed: AgentOutput = self.AgentOutput(**parsed_json) + if parsed is None: + raise ValueError(f'Could not parse response.') - @time_execution_async('--step') + # cut the number of actions to max_actions_per_step + parsed.action = parsed.action[: self.max_actions_per_step] + self._log_response(parsed) + self.n_steps += 1 + + return parsed + + @time_execution_async("--step") async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: """Execute one step of the task""" - logger.info(f'\nšŸ“ Step {self.n_steps}') + logger.info(f"\nšŸ“ Step {self.n_steps}") state = None model_output = None result: list[ActionResult] = [] @@ -179,7 +195,7 @@ class CustomAgent(Agent): input_messages = self.message_manager.get_messages() model_output = await self.get_next_action(input_messages) self.update_step_info(model_output, step_info) - logger.info(f'🧠 All Memory: {step_info.memory}') + logger.info(f"🧠 All Memory: {step_info.memory}") self._save_conversation(input_messages, model_output) self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history self.message_manager.add_model_output(model_output) @@ -190,7 +206,7 @@ class CustomAgent(Agent): self._last_result = result if len(result) > 0 and result[-1].is_done: - logger.info(f'šŸ“„ Result: {result[-1].extracted_content}') + logger.info(f"šŸ“„ Result: {result[-1].extracted_content}") self.consecutive_failures = 0 @@ -215,7 +231,7 @@ class CustomAgent(Agent): async def run(self, max_steps: int = 100) -> AgentHistoryList: """Execute the task with maximum number of steps""" try: - logger.info(f'šŸš€ Starting task: {self.task}') + logger.info(f"šŸš€ Starting task: {self.task}") self.telemetry.capture( AgentRunTelemetryEvent( @@ -224,13 +240,14 @@ class CustomAgent(Agent): ) ) - step_info = CustomAgentStepInfo(task=self.task, - add_infos=self.add_infos, - step_number=1, - max_steps=max_steps, - memory='', - task_progress='' - ) + step_info = CustomAgentStepInfo( + task=self.task, + add_infos=self.add_infos, + step_number=1, + max_steps=max_steps, + memory="", + task_progress="", + ) for step in range(max_steps): if self._too_many_failures(): @@ -245,10 +262,10 @@ class CustomAgent(Agent): if not await self._validate_output(): continue - logger.info('āœ… Task completed successfully') + logger.info("āœ… Task completed successfully") break else: - logger.info('āŒ Failed to complete task in maximum steps') + logger.info("āŒ Failed to complete task in maximum steps") return self.history