diff --git a/README.md b/README.md index 529a9df..7c8297a 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ playwright install - `--dark-mode`: Enables dark mode for the user interface. 3. **Access the WebUI:** Open your web browser and navigate to `http://127.0.0.1:7788`. 4. **Using Your Own Browser(Optional):** - - Set `CHROME_PATH` to the executable path of your browser and `CHROME_USER_DATA` to the user data directory of your browser. + - Set `CHROME_PATH` to the executable path of your browser and `CHROME_USER_DATA` to the user data directory of your browser. Leave `CHROME_USER_DATA` empty if you want to use local user data. - Windows ```env CHROME_PATH="C:\Program Files\Google\Chrome\Application\chrome.exe" @@ -118,7 +118,7 @@ playwright install - Mac ```env CHROME_PATH="/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" - CHROME_USER_DATA="~/Library/Application Support/Google/Chrome/Profile 1" + CHROME_USER_DATA="/Users/YourUsername/Library/Application Support/Google/Chrome" ``` - Close all Chrome windows - Open the WebUI in a non-Chrome browser, such as Firefox or Edge. This is important because the persistent browser context will use the Chrome data when running the agent. diff --git a/requirements.txt b/requirements.txt index 8fa4294..34e4b0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ browser-use==0.1.29 pyperclip==1.9.0 gradio==5.10.0 +json-repair diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index 77ba6c3..10be78d 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -8,10 +8,11 @@ import os import base64 import io import platform -from browser_use.agent.prompts import SystemPrompt +from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt from browser_use.agent.service import Agent from browser_use.agent.views import ( ActionResult, + ActionModel, AgentHistoryList, AgentOutput, AgentHistory, @@ -30,6 +31,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( BaseMessage, ) +from json_repair import repair_json from src.utils.agent_state import AgentState from .custom_massage_manager import CustomMassageManager @@ -52,6 +54,7 @@ class CustomAgent(Agent): max_failures: int = 5, retry_delay: int = 10, system_prompt_class: Type[SystemPrompt] = SystemPrompt, + agent_prompt_class: Type[AgentMessagePrompt] = AgentMessagePrompt, max_input_tokens: int = 128000, validate_output: bool = False, include_attributes: list[str] = [ @@ -98,7 +101,7 @@ class CustomAgent(Agent): register_done_callback=register_done_callback, tool_calling_method=tool_calling_method ) - if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"): + if self.model_name in ["deepseek-reasoner"] or "deepseek-r1" in self.model_name: # deepseek-reasoner does not support function calling self.use_deepseek_r1 = True # deepseek-reasoner only support 64000 context @@ -106,20 +109,23 @@ class CustomAgent(Agent): else: self.use_deepseek_r1 = False + # record last actions + self._last_actions = None # custom new info self.add_infos = add_infos # agent_state for Stop self.agent_state = agent_state + self.agent_prompt_class = agent_prompt_class self.message_manager = CustomMassageManager( llm=self.llm, task=self.task, action_descriptions=self.controller.registry.get_prompt_description(), system_prompt_class=self.system_prompt_class, + agent_prompt_class=agent_prompt_class, max_input_tokens=self.max_input_tokens, include_attributes=self.include_attributes, max_error_length=self.max_error_length, - max_actions_per_step=self.max_actions_per_step, - use_deepseek_r1=self.use_deepseek_r1 + max_actions_per_step=self.max_actions_per_step ) def _setup_action_models(self) -> None: @@ -178,38 +184,39 @@ class CustomAgent(Agent): @time_execution_async("--get_next_action") async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput: """Get next action from LLM based on current state""" - if self.use_deepseek_r1: - merged_input_messages = self.message_manager.merge_successive_human_messages(input_messages) - ai_message = self.llm.invoke(merged_input_messages) - self.message_manager._add_message_with_tokens(ai_message) - logger.info(f"🤯 Start Deep Thinking: ") - logger.info(ai_message.reasoning_content) - logger.info(f"🤯 End Deep Thinking") - if isinstance(ai_message.content, list): - parsed_json = json.loads(ai_message.content[0].replace("```json", "").replace("```", "")) - else: - parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", "")) - parsed: AgentOutput = self.AgentOutput(**parsed_json) - if parsed is None: - logger.debug(ai_message.content) - raise ValueError(f'Could not parse response.') - else: - ai_message = self.llm.invoke(input_messages) - self.message_manager._add_message_with_tokens(ai_message) - if isinstance(ai_message.content, list): - parsed_json = json.loads(ai_message.content[0].replace("```json", "").replace("```", "")) - else: - parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", "")) - parsed: AgentOutput = self.AgentOutput(**parsed_json) - if parsed is None: - logger.debug(ai_message.content) - raise ValueError(f'Could not parse response.') + messages_to_process = ( + self.message_manager.merge_successive_human_messages(input_messages) + if self.use_deepseek_r1 + else input_messages + ) - # cut the number of actions to max_actions_per_step + ai_message = self.llm.invoke(messages_to_process) + self.message_manager._add_message_with_tokens(ai_message) + + if self.use_deepseek_r1: + logger.info("🤯 Start Deep Thinking: ") + logger.info(ai_message.reasoning_content) + logger.info("🤯 End Deep Thinking") + + if isinstance(ai_message.content, list): + ai_content = ai_message.content[0] + else: + ai_content = ai_message.content + + ai_content = ai_content.replace("```json", "").replace("```", "") + ai_content = repair_json(ai_content) + parsed_json = json.loads(ai_content) + parsed: AgentOutput = self.AgentOutput(**parsed_json) + + if parsed is None: + logger.debug(ai_message.content) + raise ValueError('Could not parse response.') + + # Limit actions to maximum allowed per step parsed.action = parsed.action[: self.max_actions_per_step] self._log_response(parsed) self.n_steps += 1 - + return parsed @time_execution_async("--step") @@ -222,7 +229,7 @@ class CustomAgent(Agent): try: state = await self.browser_context.get_state(use_vision=self.use_vision) - self.message_manager.add_state_message(state, self._last_result, step_info) + self.message_manager.add_state_message(state, self._last_actions, self._last_result, step_info) input_messages = self.message_manager.get_messages() try: model_output = await self.get_next_action(input_messages) @@ -231,27 +238,31 @@ class CustomAgent(Agent): self.update_step_info(model_output, step_info) logger.info(f"🧠 All Memory: \n{step_info.memory}") self._save_conversation(input_messages, model_output) - # should we remove last state message? at least, deepseek-reasoner cannot remove if self.model_name != "deepseek-reasoner": - self.message_manager._remove_last_state_message() + # remove prev message + self.message_manager._remove_state_message_by_index(-1) except Exception as e: # model call failed, remove last state message from history - self.message_manager._remove_last_state_message() + self.message_manager._remove_state_message_by_index(-1) raise e + actions: list[ActionModel] = model_output.action result: list[ActionResult] = await self.controller.multi_act( - model_output.action, self.browser_context + actions, self.browser_context ) - if len(result) != len(model_output.action): + if len(result) != len(actions): # I think something changes, such information should let LLM know - for ri in range(len(result), len(model_output.action)): + for ri in range(len(result), len(actions)): result.append(ActionResult(extracted_content=None, include_in_memory=True, - error=f"{model_output.action[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \ - Something new appeared after action {model_output.action[len(result) - 1].model_dump_json(exclude_unset=True)}", + error=f"{actions[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \ + Something new appeared after action {actions[len(result) - 1].model_dump_json(exclude_unset=True)}", is_done=False)) + if len(actions) == 0: + # TODO: fix no action case + result = [ActionResult(is_done=True, extracted_content=step_info.memory, include_in_memory=True)] self._last_result = result - + self._last_actions = actions if len(result) > 0 and result[-1].is_done: logger.info(f"📄 Result: {result[-1].extracted_content}") diff --git a/src/agent/custom_massage_manager.py b/src/agent/custom_massage_manager.py index 3a6bb32..f39c999 100644 --- a/src/agent/custom_massage_manager.py +++ b/src/agent/custom_massage_manager.py @@ -5,8 +5,8 @@ from typing import List, Optional, Type from browser_use.agent.message_manager.service import MessageManager from browser_use.agent.message_manager.views import MessageHistory -from browser_use.agent.prompts import SystemPrompt -from browser_use.agent.views import ActionResult, AgentStepInfo +from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt +from browser_use.agent.views import ActionResult, AgentStepInfo, ActionModel from browser_use.browser.views import BrowserState from langchain_core.language_models import BaseChatModel from langchain_anthropic import ChatAnthropic @@ -31,14 +31,14 @@ class CustomMassageManager(MessageManager): task: str, action_descriptions: str, system_prompt_class: Type[SystemPrompt], + agent_prompt_class: Type[AgentMessagePrompt], max_input_tokens: int = 128000, estimated_characters_per_token: int = 3, image_tokens: int = 800, include_attributes: list[str] = [], max_error_length: int = 400, max_actions_per_step: int = 10, - message_context: Optional[str] = None, - use_deepseek_r1: bool = False + message_context: Optional[str] = None ): super().__init__( llm=llm, @@ -53,8 +53,7 @@ class CustomMassageManager(MessageManager): max_actions_per_step=max_actions_per_step, message_context=message_context ) - self.tool_id = 1 - self.use_deepseek_r1 = use_deepseek_r1 + self.agent_prompt_class = agent_prompt_class # Custom: Move Task info to state_message self.history = MessageHistory() self._add_message_with_tokens(self.system_prompt) @@ -75,13 +74,15 @@ class CustomMassageManager(MessageManager): def add_state_message( self, state: BrowserState, + actions: Optional[List[ActionModel]] = None, result: Optional[List[ActionResult]] = None, step_info: Optional[AgentStepInfo] = None, ) -> None: """Add browser state as human message""" # otherwise add state message and result to next message (which will not stay in memory) - state_message = CustomAgentMessagePrompt( + state_message = self.agent_prompt_class( state, + actions, result, include_attributes=self.include_attributes, max_error_length=self.max_error_length, @@ -102,3 +103,15 @@ class CustomMassageManager(MessageManager): len(text) // self.estimated_characters_per_token ) # Rough estimate if no tokenizer available return tokens + + def _remove_state_message_by_index(self, remove_ind=-1) -> None: + """Remove last state message from history""" + i = len(self.history.messages) - 1 + remove_cnt = 0 + while i >= 0: + if isinstance(self.history.messages[i].message, HumanMessage): + remove_cnt += 1 + if remove_cnt == abs(remove_ind): + self.history.remove_message(i) + break + i -= 1 \ No newline at end of file diff --git a/src/agent/custom_prompts.py b/src/agent/custom_prompts.py index f42859e..1e1df63 100644 --- a/src/agent/custom_prompts.py +++ b/src/agent/custom_prompts.py @@ -2,7 +2,7 @@ import pdb from typing import List, Optional from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt -from browser_use.agent.views import ActionResult +from browser_use.agent.views import ActionResult, ActionModel from browser_use.browser.views import BrowserState from langchain_core.messages import HumanMessage, SystemMessage @@ -56,7 +56,7 @@ class CustomSystemPrompt(SystemPrompt): - Use scroll to find elements you are looking for 5. TASK COMPLETION: - - If you think all the requirements of user\'s instruction have been completed and no further operation is required, output the done action to terminate the operation process. + - If you think all the requirements of user\'s instruction have been completed and no further operation is required, output the **Done** action to terminate the operation process. - Don't hallucinate actions. - If the task requires specific information - make sure to include everything in the done function. This is what the user will see. - If you are running out of steps (current step), think about speeding it up, and ALWAYS use the done action as the last action. @@ -140,6 +140,7 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): def __init__( self, state: BrowserState, + actions: Optional[List[ActionModel]] = None, result: Optional[List[ActionResult]] = None, include_attributes: list[str] = [], max_error_length: int = 400, @@ -151,10 +152,11 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): max_error_length=max_error_length, step_info=step_info ) + self.actions = actions def get_user_message(self) -> HumanMessage: if self.step_info: - step_info_description = f'Current step: {self.step_info.step_number + 1}/{self.step_info.max_steps}' + step_info_description = f'Current step: {self.step_info.step_number}/{self.step_info.max_steps}\n' else: step_info_description = '' @@ -181,7 +183,7 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): state_description = f""" {step_info_description} -1. Task: {self.step_info.task} +1. Task: {self.step_info.task}. 2. Hints(Optional): {self.step_info.add_infos} 3. Memory: @@ -193,17 +195,20 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): {elements_text} """ - if self.result: - + if self.actions and self.result: + state_description += "\n **Previous Actions** \n" + state_description += f'Previous step: {self.step_info.step_number-1}/{self.step_info.max_steps} \n' for i, result in enumerate(self.result): + action = self.actions[i] + state_description += f"Previous action {i + 1}/{len(self.result)}: {action.model_dump_json(exclude_unset=True)}\n" if result.include_in_memory: if result.extracted_content: - state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}" + state_description += f"Result of previous action {i + 1}/{len(self.result)}: {result.extracted_content}\n" if result.error: # only use last 300 characters of error error = result.error[-self.max_error_length:] state_description += ( - f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}" + f"Error of previous action {i + 1}/{len(self.result)}: ...{error}\n" ) if self.state.screenshot: diff --git a/src/controller/custom_controller.py b/src/controller/custom_controller.py index a89bef0..4e2ca0f 100644 --- a/src/controller/custom_controller.py +++ b/src/controller/custom_controller.py @@ -3,7 +3,7 @@ from typing import Optional, Type from pydantic import BaseModel from browser_use.agent.views import ActionResult from browser_use.browser.context import BrowserContext -from browser_use.controller.service import Controller +from browser_use.controller.service import Controller, DoneAction class CustomController(Controller): diff --git a/src/utils/utils.py b/src/utils/utils.py index 7f524d1..73e9066 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -13,6 +13,14 @@ import gradio as gr from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama +PROVIDER_DISPLAY_NAMES = { + "openai": "OpenAI", + "azure_openai": "Azure OpenAI", + "anthropic": "Anthropic", + "deepseek": "DeepSeek", + "gemini": "Gemini" +} + def get_llm_model(provider: str, **kwargs): """ 获取LLM 模型 @@ -20,17 +28,19 @@ def get_llm_model(provider: str, **kwargs): :param kwargs: :return: """ + if provider not in ["ollama"]: + env_var = "GOOGLE_API_KEY" if provider == "gemini" else f"{provider.upper()}_API_KEY" + api_key = kwargs.get("api_key", "") or os.getenv(env_var, "") + if not api_key: + handle_api_key_error(provider, env_var) + kwargs["api_key"] = api_key + if provider == "anthropic": if not kwargs.get("base_url", ""): base_url = "https://api.anthropic.com" else: base_url = kwargs.get("base_url") - if not kwargs.get("api_key", ""): - api_key = os.getenv("ANTHROPIC_API_KEY", "") - else: - api_key = kwargs.get("api_key") - return ChatAnthropic( model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"), temperature=kwargs.get("temperature", 0.0), @@ -59,11 +69,6 @@ def get_llm_model(provider: str, **kwargs): else: base_url = kwargs.get("base_url") - if not kwargs.get("api_key", ""): - api_key = os.getenv("OPENAI_API_KEY", "") - else: - api_key = kwargs.get("api_key") - return ChatOpenAI( model=kwargs.get("model_name", "gpt-4o"), temperature=kwargs.get("temperature", 0.0), @@ -76,11 +81,6 @@ def get_llm_model(provider: str, **kwargs): else: base_url = kwargs.get("base_url") - if not kwargs.get("api_key", ""): - api_key = os.getenv("DEEPSEEK_API_KEY", "") - else: - api_key = kwargs.get("api_key") - if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner": return DeepSeekR1ChatOpenAI( model=kwargs.get("model_name", "deepseek-reasoner"), @@ -96,10 +96,6 @@ def get_llm_model(provider: str, **kwargs): api_key=api_key, ) elif provider == "gemini": - if not kwargs.get("api_key", ""): - api_key = os.getenv("GOOGLE_API_KEY", "") - else: - api_key = kwargs.get("api_key") return ChatGoogleGenerativeAI( model=kwargs.get("model_name", "gemini-2.0-flash-exp"), temperature=kwargs.get("temperature", 0.0), @@ -111,9 +107,9 @@ def get_llm_model(provider: str, **kwargs): else: base_url = kwargs.get("base_url") - if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"): + if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"): return DeepSeekR1ChatOllama( - model=kwargs.get("model_name", "deepseek-r1:7b"), + model=kwargs.get("model_name", "deepseek-r1:14b"), temperature=kwargs.get("temperature", 0.0), num_ctx=kwargs.get("num_ctx", 32000), base_url=kwargs.get("base_url", base_url), @@ -123,6 +119,7 @@ def get_llm_model(provider: str, **kwargs): model=kwargs.get("model_name", "qwen2.5:7b"), temperature=kwargs.get("temperature", 0.0), num_ctx=kwargs.get("num_ctx", 32000), + num_predict=kwargs.get("num_predict", 1024), base_url=kwargs.get("base_url", base_url), ) elif provider == "azure_openai": @@ -130,10 +127,6 @@ def get_llm_model(provider: str, **kwargs): base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "") else: base_url = kwargs.get("base_url") - if not kwargs.get("api_key", ""): - api_key = os.getenv("AZURE_OPENAI_API_KEY", "") - else: - api_key = kwargs.get("api_key") return AzureChatOpenAI( model=kwargs.get("model_name", "gpt-4o"), temperature=kwargs.get("temperature", 0.0), @@ -171,7 +164,17 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None): return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True) else: return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True) - + +def handle_api_key_error(provider: str, env_var: str): + """ + Handles the missing API key error by raising a gr.Error with a clear message. + """ + provider_display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper()) + raise gr.Error( + f"💥 {provider_display} API key not found! 🔑 Please set the " + f"`{env_var}` environment variable or provide it in the UI." + ) + def encode_image(img_path): if not img_path: return None diff --git a/tests/test_browser_use.py b/tests/test_browser_use.py index 1921995..c9d1129 100644 --- a/tests/test_browser_use.py +++ b/tests/test_browser_use.py @@ -32,10 +32,14 @@ async def test_browser_use_org(): # api_key=os.getenv("AZURE_OPENAI_API_KEY", ""), # ) + # llm = utils.get_llm_model( + # provider="deepseek", + # model_name="deepseek-chat", + # temperature=0.8 + # ) + llm = utils.get_llm_model( - provider="deepseek", - model_name="deepseek-chat", - temperature=0.8 + provider="ollama", model_name="deepseek-r1:14b", temperature=0.5 ) window_w, window_h = 1920, 1080 @@ -99,151 +103,29 @@ async def test_browser_use_custom(): from playwright.async_api import async_playwright from src.agent.custom_agent import CustomAgent - from src.agent.custom_prompts import CustomSystemPrompt + from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePrompt from src.browser.custom_browser import CustomBrowser from src.browser.custom_context import BrowserContextConfig from src.controller.custom_controller import CustomController window_w, window_h = 1920, 1080 - + # llm = utils.get_llm_model( - # provider="azure_openai", + # provider="openai", # model_name="gpt-4o", # temperature=0.8, - # base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""), - # api_key=os.getenv("AZURE_OPENAI_API_KEY", ""), + # base_url=os.getenv("OPENAI_ENDPOINT", ""), + # api_key=os.getenv("OPENAI_API_KEY", ""), # ) llm = utils.get_llm_model( - provider="gemini", - model_name="gemini-2.0-flash-exp", - temperature=1.0, - api_key=os.getenv("GOOGLE_API_KEY", "") + provider="azure_openai", + model_name="gpt-4o", + temperature=0.8, + base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""), + api_key=os.getenv("AZURE_OPENAI_API_KEY", ""), ) - # llm = utils.get_llm_model( - # provider="deepseek", - # model_name="deepseek-chat", - # temperature=0.8 - # ) - - # llm = utils.get_llm_model( - # provider="ollama", model_name="qwen2.5:7b", temperature=0.8 - # ) - - controller = CustomController() - use_own_browser = False - disable_security = True - use_vision = True # Set to False when using DeepSeek - tool_call_in_content = True # Set to True when using Ollama - max_actions_per_step = 1 - playwright = None - browser_context_ = None - try: - if use_own_browser: - playwright = await async_playwright().start() - chrome_exe = os.getenv("CHROME_PATH", "") - chrome_use_data = os.getenv("CHROME_USER_DATA", "") - browser_context_ = await playwright.chromium.launch_persistent_context( - user_data_dir=chrome_use_data, - executable_path=chrome_exe, - no_viewport=False, - headless=False, # 保持浏览器窗口可见 - user_agent=( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " - "(KHTML, like Gecko) Chrome/85.0.4183.102 Safari/537.36" - ), - java_script_enabled=True, - bypass_csp=disable_security, - ignore_https_errors=disable_security, - record_video_dir="./tmp/record_videos", - record_video_size={"width": window_w, "height": window_h}, - ) - else: - browser_context_ = None - - browser = CustomBrowser( - config=BrowserConfig( - headless=False, - disable_security=True, - extra_chromium_args=[f"--window-size={window_w},{window_h}"], - ) - ) - - async with await browser.new_context( - config=BrowserContextConfig( - trace_path="./tmp/result_processing", - save_recording_path="./tmp/record_videos", - no_viewport=False, - browser_window_size=BrowserContextWindowSize( - width=window_w, height=window_h - ), - ), - context=browser_context_, - ) as browser_context: - agent = CustomAgent( - task="go to google.com and type 'OpenAI' click search and give me the first url", - add_infos="", # some hints for llm to complete the task - llm=llm, - browser_context=browser_context, - controller=controller, - system_prompt_class=CustomSystemPrompt, - use_vision=use_vision, - tool_call_in_content=tool_call_in_content, - max_actions_per_step=max_actions_per_step - ) - history: AgentHistoryList = await agent.run(max_steps=10) - - print("Final Result:") - pprint(history.final_result(), indent=4) - - print("\nErrors:") - pprint(history.errors(), indent=4) - - # e.g. xPaths the model clicked on - print("\nModel Outputs:") - pprint(history.model_actions(), indent=4) - - print("\nThoughts:") - pprint(history.model_thoughts(), indent=4) - # close browser - except Exception: - import traceback - - traceback.print_exc() - finally: - # 显式关闭持久化上下文 - if browser_context_: - await browser_context_.close() - - # 关闭 Playwright 对象 - if playwright: - await playwright.stop() - - await browser.close() - - -async def test_browser_use_custom_v2(): - from browser_use.browser.context import BrowserContextWindowSize - from browser_use.browser.browser import BrowserConfig - from playwright.async_api import async_playwright - - from src.agent.custom_agent import CustomAgent - from src.agent.custom_prompts import CustomSystemPrompt - from src.browser.custom_browser import CustomBrowser - from src.browser.custom_context import BrowserContextConfig - from src.controller.custom_controller import CustomController - - window_w, window_h = 1920, 1080 - - # llm = utils.get_llm_model( - # provider="azure_openai", - # model_name="gpt-4o", - # temperature=0.8, - # base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""), - # api_key=os.getenv("AZURE_OPENAI_API_KEY", ""), - # ) - # llm = utils.get_llm_model( # provider="gemini", # model_name="gemini-2.0-flash-exp", @@ -272,20 +154,24 @@ async def test_browser_use_custom_v2(): # ) controller = CustomController() - use_own_browser = False + use_own_browser = True disable_security = True use_vision = False # Set to False when using DeepSeek - max_actions_per_step = 10 + max_actions_per_step = 1 playwright = None browser = None browser_context = None try: + extra_chromium_args = [f"--window-size={window_w},{window_h}"] if use_own_browser: chrome_path = os.getenv("CHROME_PATH", None) if chrome_path == "": chrome_path = None + chrome_user_data = os.getenv("CHROME_USER_DATA", None) + if chrome_user_data: + extra_chromium_args += [f"--user-data-dir={chrome_user_data}"] else: chrome_path = None browser = CustomBrowser( @@ -293,7 +179,7 @@ async def test_browser_use_custom_v2(): headless=False, disable_security=disable_security, chrome_instance_path=chrome_path, - extra_chromium_args=[f"--window-size={window_w},{window_h}"], + extra_chromium_args=extra_chromium_args, ) ) browser_context = await browser.new_context( @@ -307,17 +193,18 @@ async def test_browser_use_custom_v2(): ) ) agent = CustomAgent( - task="go to google.com and type 'Nvidia' click search and give me the first url", + task="Search 'Nvidia' and give me the first url", add_infos="", # some hints for llm to complete the task llm=llm, browser=browser, browser_context=browser_context, controller=controller, system_prompt_class=CustomSystemPrompt, + agent_prompt_class=CustomAgentMessagePrompt, use_vision=use_vision, max_actions_per_step=max_actions_per_step ) - history: AgentHistoryList = await agent.run(max_steps=10) + history: AgentHistoryList = await agent.run(max_steps=100) print("Final Result:") pprint(history.final_result(), indent=4) @@ -349,5 +236,4 @@ async def test_browser_use_custom_v2(): if __name__ == "__main__": # asyncio.run(test_browser_use_org()) - # asyncio.run(test_browser_use_custom()) - asyncio.run(test_browser_use_custom_v2()) + asyncio.run(test_browser_use_custom()) diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 8628b6a..45d5775 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -1,7 +1,10 @@ import os import pdb +from dataclasses import dataclass from dotenv import load_dotenv +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_ollama import ChatOllama load_dotenv() @@ -9,174 +12,115 @@ import sys sys.path.append(".") +@dataclass +class LLMConfig: + provider: str + model_name: str + temperature: float = 0.8 + base_url: str = None + api_key: str = None -def test_mistral_model(): - from langchain_core.messages import HumanMessage +def create_message_content(text, image_path=None): + content = [{"type": "text", "text": text}] + + if image_path: + from src.utils import utils + image_data = utils.encode_image(image_path) + content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image_data}"} + }) + + return content + +def get_env_value(key, provider): + env_mappings = { + "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"}, + "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"}, + "gemini": {"api_key": "GOOGLE_API_KEY"}, + "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"} + } + + if provider in env_mappings and key in env_mappings[provider]: + return os.getenv(env_mappings[provider][key], "") + return "" + +def test_llm(config, query, image_path=None, system_message=None): from src.utils import utils + # Special handling for Ollama-based models + if config.provider == "ollama": + if "deepseek-r1" in config.model_name: + from src.utils.llm import DeepSeekR1ChatOllama + llm = DeepSeekR1ChatOllama(model=config.model_name) + else: + llm = ChatOllama(model=config.model_name) + + ai_msg = llm.invoke(query) + print(ai_msg.content) + if "deepseek-r1" in config.model_name: + pdb.set_trace() + return + + # For other providers, use the standard configuration llm = utils.get_llm_model( - provider="mistral", - model_name="mistral-large-latest", - temperature=0.8, - base_url=os.getenv("MISTRAL_ENDPOINT", ""), - api_key=os.getenv("MISTRAL_API_KEY", "") + provider=config.provider, + model_name=config.model_name, + temperature=config.temperature, + base_url=config.base_url or get_env_value("base_url", config.provider), + api_key=config.api_key or get_env_value("api_key", config.provider) ) - message = HumanMessage( - content=[ - {"type": "text", "text": "who are you?"} - ] - ) - ai_msg = llm.invoke([message]) + + # Prepare messages for non-Ollama models + messages = [] + if system_message: + messages.append(SystemMessage(content=create_message_content(system_message))) + messages.append(HumanMessage(content=create_message_content(query, image_path))) + ai_msg = llm.invoke(messages) + + # Handle different response types + if hasattr(ai_msg, "reasoning_content"): + print(ai_msg.reasoning_content) print(ai_msg.content) + if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name: + print(llm.model_name) + pdb.set_trace() + def test_openai_model(): - from langchain_core.messages import HumanMessage - from src.utils import utils - - llm = utils.get_llm_model( - provider="openai", - model_name="gpt-4o", - temperature=0.8, - base_url=os.getenv("OPENAI_ENDPOINT", ""), - api_key=os.getenv("OPENAI_API_KEY", "") - ) - image_path = "assets/examples/test.png" - image_data = utils.encode_image(image_path) - message = HumanMessage( - content=[ - {"type": "text", "text": "describe this image"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, - }, - ] - ) - ai_msg = llm.invoke([message]) - print(ai_msg.content) - + config = LLMConfig(provider="openai", model_name="gpt-4o") + test_llm(config, "Describe this image", "assets/examples/test.png") def test_gemini_model(): - # you need to enable your api key first: https://ai.google.dev/palm_docs/oauth_quickstart - from langchain_core.messages import HumanMessage - from src.utils import utils - - llm = utils.get_llm_model( - provider="gemini", - model_name="gemini-2.0-flash-exp", - temperature=0.8, - api_key=os.getenv("GOOGLE_API_KEY", "") - ) - - image_path = "assets/examples/test.png" - image_data = utils.encode_image(image_path) - message = HumanMessage( - content=[ - {"type": "text", "text": "describe this image"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, - }, - ] - ) - ai_msg = llm.invoke([message]) - print(ai_msg.content) - + # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart + config = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp") + test_llm(config, "Describe this image", "assets/examples/test.png") def test_azure_openai_model(): - from langchain_core.messages import HumanMessage - from src.utils import utils - - llm = utils.get_llm_model( - provider="azure_openai", - model_name="gpt-4o", - temperature=0.8, - base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""), - api_key=os.getenv("AZURE_OPENAI_API_KEY", "") - ) - image_path = "assets/examples/test.png" - image_data = utils.encode_image(image_path) - message = HumanMessage( - content=[ - {"type": "text", "text": "describe this image"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, - }, - ] - ) - ai_msg = llm.invoke([message]) - print(ai_msg.content) - + config = LLMConfig(provider="azure_openai", model_name="gpt-4o") + test_llm(config, "Describe this image", "assets/examples/test.png") def test_deepseek_model(): - from langchain_core.messages import HumanMessage - from src.utils import utils - - llm = utils.get_llm_model( - provider="deepseek", - model_name="deepseek-chat", - temperature=0.8, - base_url=os.getenv("DEEPSEEK_ENDPOINT", ""), - api_key=os.getenv("DEEPSEEK_API_KEY", "") - ) - message = HumanMessage( - content=[ - {"type": "text", "text": "who are you?"} - ] - ) - ai_msg = llm.invoke([message]) - print(ai_msg.content) + config = LLMConfig(provider="deepseek", model_name="deepseek-chat") + test_llm(config, "Who are you?") def test_deepseek_r1_model(): - from langchain_core.messages import HumanMessage, SystemMessage, AIMessage - from src.utils import utils - - llm = utils.get_llm_model( - provider="deepseek", - model_name="deepseek-reasoner", - temperature=0.8, - base_url=os.getenv("DEEPSEEK_ENDPOINT", ""), - api_key=os.getenv("DEEPSEEK_API_KEY", "") - ) - messages = [] - sys_message = SystemMessage( - content=[{"type": "text", "text": "you are a helpful AI assistant"}] - ) - messages.append(sys_message) - user_message = HumanMessage( - content=[ - {"type": "text", "text": "9.11 and 9.8, which is greater?"} - ] - ) - messages.append(user_message) - ai_msg = llm.invoke(messages) - print(ai_msg.reasoning_content) - print(ai_msg.content) - print(llm.model_name) - pdb.set_trace() + config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner") + test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.") def test_ollama_model(): - from langchain_ollama import ChatOllama + config = LLMConfig(provider="ollama", model_name="qwen2.5:7b") + test_llm(config, "Sing a ballad of LangChain.") - llm = ChatOllama(model="qwen2.5:7b") - ai_msg = llm.invoke("Sing a ballad of LangChain.") - print(ai_msg.content) - def test_deepseek_r1_ollama_model(): - from src.utils.llm import DeepSeekR1ChatOllama + config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b") + test_llm(config, "How many 'r's are in the word 'strawberry'?") - llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b") - ai_msg = llm.invoke("how many r in strawberry?") - print(ai_msg.content) - pdb.set_trace() - - -if __name__ == '__main__': +if __name__ == "__main__": # test_openai_model() # test_gemini_model() # test_azure_openai_model() - # test_deepseek_model() + test_deepseek_model() # test_ollama_model() # test_deepseek_r1_model() # test_deepseek_r1_ollama_model() - test_mistral_model() \ No newline at end of file diff --git a/webui.py b/webui.py index f2035f3..f760aab 100644 --- a/webui.py +++ b/webui.py @@ -28,7 +28,7 @@ from src.utils.agent_state import AgentState from src.utils import utils from src.agent.custom_agent import CustomAgent from src.browser.custom_browser import CustomBrowser -from src.agent.custom_prompts import CustomSystemPrompt +from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePrompt from src.browser.custom_context import BrowserContextConfig, CustomBrowserContext from src.controller.custom_controller import CustomController from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft, Base @@ -184,6 +184,9 @@ async def run_browser_agent( gr.update(interactive=True) # Re-enable run button ) + except gr.Error: + raise + except Exception as e: import traceback traceback.print_exc() @@ -224,20 +227,24 @@ async def run_org_agent( # Clear any previous stop request _global_agent_state.clear_stop() + extra_chromium_args = [f"--window-size={window_w},{window_h}"] if use_own_browser: chrome_path = os.getenv("CHROME_PATH", None) if chrome_path == "": chrome_path = None + chrome_user_data = os.getenv("CHROME_USER_DATA", None) + if chrome_user_data: + extra_chromium_args += [f"--user-data-dir={chrome_user_data}"] else: chrome_path = None - + if _global_browser is None: _global_browser = Browser( config=BrowserConfig( headless=headless, disable_security=disable_security, chrome_instance_path=chrome_path, - extra_chromium_args=[f"--window-size={window_w},{window_h}"], + extra_chromium_args=extra_chromium_args, ) ) @@ -315,10 +322,14 @@ async def run_custom_agent( # Clear any previous stop request _global_agent_state.clear_stop() + extra_chromium_args = [f"--window-size={window_w},{window_h}"] if use_own_browser: chrome_path = os.getenv("CHROME_PATH", None) if chrome_path == "": chrome_path = None + chrome_user_data = os.getenv("CHROME_USER_DATA", None) + if chrome_user_data: + extra_chromium_args += [f"--user-data-dir={chrome_user_data}"] else: chrome_path = None @@ -331,7 +342,7 @@ async def run_custom_agent( headless=headless, disable_security=disable_security, chrome_instance_path=chrome_path, - extra_chromium_args=[f"--window-size={window_w},{window_h}"], + extra_chromium_args=extra_chromium_args, ) ) @@ -357,6 +368,7 @@ async def run_custom_agent( browser_context=_global_browser_context, controller=controller, system_prompt_class=CustomSystemPrompt, + agent_prompt_class=CustomAgentMessagePrompt, max_actions_per_step=max_actions_per_step, agent_state=_global_agent_state, tool_calling_method=tool_calling_method @@ -526,6 +538,12 @@ async def run_with_stream( try: result = await agent_task final_result, errors, model_actions, model_thoughts, latest_videos, trace, history_file, stop_button, run_button = result + except gr.Error: + final_result = "" + model_actions = "" + model_thoughts = "" + latest_videos = trace = history_file = None + except Exception as e: errors = f"Agent error: {str(e)}"