From 75ab5051ec2efb57d41ad87945b384dbc02b06dc Mon Sep 17 00:00:00 2001 From: vincent Date: Tue, 28 Jan 2025 20:38:29 +0800 Subject: [PATCH] fix deepseek-r1 ollama --- src/agent/custom_agent.py | 15 +++++++++------ src/agent/custom_massage_manager.py | 24 ++++++++++++------------ src/agent/custom_prompts.py | 2 +- src/utils/utils.py | 5 ++--- tests/test_browser_use.py | 16 ++++++++++------ 5 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index 355a8ff..81f33c8 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -242,17 +242,17 @@ class CustomAgent(Agent): logger.info(f"🧠 All Memory: \n{step_info.memory}") self._save_conversation(input_messages, model_output) if self.model_name != "deepseek-reasoner": - # remove pre-prev message - self.message_manager._remove_last_state_message() + # remove prev message + self.message_manager._remove_state_message_by_index(-1) except Exception as e: # model call failed, remove last state message from history - self.message_manager._remove_last_state_message() + self.message_manager._remove_state_message_by_index(-1) raise e - result: list[ActionResult] = await self.controller.multi_act( - model_output.action, self.browser_context - ) actions: list[ActionModel] = model_output.action + result: list[ActionResult] = await self.controller.multi_act( + actions, self.browser_context + ) if len(result) != len(actions): # I think something changes, such information should let LLM know for ri in range(len(result), len(actions)): @@ -261,6 +261,9 @@ class CustomAgent(Agent): error=f"{actions[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \ Something new appeared after action {actions[len(result) - 1].model_dump_json(exclude_unset=True)}", is_done=False)) + if len(actions) == 0: + # TODO: fix no action case + result = [ActionResult(is_done=True, extracted_content=step_info.memory, include_in_memory=True)] self._last_result = result self._last_actions = actions if len(result) > 0 and result[-1].is_done: diff --git a/src/agent/custom_massage_manager.py b/src/agent/custom_massage_manager.py index e6fb1b5..f39c999 100644 --- a/src/agent/custom_massage_manager.py +++ b/src/agent/custom_massage_manager.py @@ -70,18 +70,6 @@ class CustomMassageManager(MessageManager): while diff > 0 and len(self.history.messages) > min_message_len: self.history.remove_message(min_message_len) # alway remove the oldest message diff = self.history.total_tokens - self.max_input_tokens - - def _remove_state_message_by_index(self, remove_ind=-1) -> None: - """Remove last state message from history""" - i = 0 - remove_cnt = 0 - while len(self.history.messages) and i <= len(self.history.messages): - i += 1 - if isinstance(self.history.messages[-i].message, HumanMessage): - remove_cnt += 1 - if remove_cnt == abs(remove_ind): - self.history.remove_message(-i) - break def add_state_message( self, @@ -115,3 +103,15 @@ class CustomMassageManager(MessageManager): len(text) // self.estimated_characters_per_token ) # Rough estimate if no tokenizer available return tokens + + def _remove_state_message_by_index(self, remove_ind=-1) -> None: + """Remove last state message from history""" + i = len(self.history.messages) - 1 + remove_cnt = 0 + while i >= 0: + if isinstance(self.history.messages[i].message, HumanMessage): + remove_cnt += 1 + if remove_cnt == abs(remove_ind): + self.history.remove_message(i) + break + i -= 1 \ No newline at end of file diff --git a/src/agent/custom_prompts.py b/src/agent/custom_prompts.py index 08a9040..1e1df63 100644 --- a/src/agent/custom_prompts.py +++ b/src/agent/custom_prompts.py @@ -183,7 +183,7 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): state_description = f""" {step_info_description} -1. Task: {self.step_info.task} +1. Task: {self.step_info.task}. 2. Hints(Optional): {self.step_info.add_infos} 3. Memory: diff --git a/src/utils/utils.py b/src/utils/utils.py index dd5a57f..c4218cd 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -94,12 +94,11 @@ def get_llm_model(provider: str, **kwargs): else: base_url = kwargs.get("base_url") - if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"): + if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"): return DeepSeekR1ChatOllama( - model=kwargs.get("model_name", "deepseek-r1:7b"), + model=kwargs.get("model_name", "deepseek-r1:14b"), temperature=kwargs.get("temperature", 0.0), num_ctx=kwargs.get("num_ctx", 32000), - num_predict=kwargs.get("num_predict", 1024), base_url=kwargs.get("base_url", base_url), ) else: diff --git a/tests/test_browser_use.py b/tests/test_browser_use.py index 5a40c32..c9d1129 100644 --- a/tests/test_browser_use.py +++ b/tests/test_browser_use.py @@ -32,10 +32,14 @@ async def test_browser_use_org(): # api_key=os.getenv("AZURE_OPENAI_API_KEY", ""), # ) + # llm = utils.get_llm_model( + # provider="deepseek", + # model_name="deepseek-chat", + # temperature=0.8 + # ) + llm = utils.get_llm_model( - provider="deepseek", - model_name="deepseek-chat", - temperature=0.8 + provider="ollama", model_name="deepseek-r1:14b", temperature=0.5 ) window_w, window_h = 1920, 1080 @@ -152,9 +156,9 @@ async def test_browser_use_custom(): controller = CustomController() use_own_browser = True disable_security = True - use_vision = True # Set to False when using DeepSeek + use_vision = False # Set to False when using DeepSeek - max_actions_per_step = 10 + max_actions_per_step = 1 playwright = None browser = None browser_context = None @@ -189,7 +193,7 @@ async def test_browser_use_custom(): ) ) agent = CustomAgent( - task="go to google.com and type 'OpenAI' click search and give me the first url", + task="Search 'Nvidia' and give me the first url", add_infos="", # some hints for llm to complete the task llm=llm, browser=browser,