From 664dce757e477f06a46dd461f6bf317abe65e8d1 Mon Sep 17 00:00:00 2001 From: vincent Date: Mon, 27 Jan 2025 16:36:13 +0800 Subject: [PATCH] add deepseek-r1 ollama --- .env.example | 2 ++ src/agent/custom_agent.py | 9 +++++++-- src/agent/custom_prompts.py | 15 +++++---------- src/utils/llm.py | 35 +++++++++++++++++++++++++++++++++++ src/utils/utils.py | 29 +++++++++++++++++++++-------- tests/test_browser_use.py | 18 +++++++++++------- tests/test_llm_api.py | 13 +++++++++++-- webui.py | 3 ++- 8 files changed, 94 insertions(+), 30 deletions(-) diff --git a/.env.example b/.env.example index 7b53b7a..fe2c67c 100644 --- a/.env.example +++ b/.env.example @@ -11,6 +11,8 @@ AZURE_OPENAI_API_KEY= DEEPSEEK_ENDPOINT=https://api.deepseek.com DEEPSEEK_API_KEY= +OLLAMA_ENDPOINT=http://localhost:11434 + # Set to false to disable anonymized telemetry ANONYMIZED_TELEMETRY=true diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index fc69a13..77ba6c3 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -98,7 +98,7 @@ class CustomAgent(Agent): register_done_callback=register_done_callback, tool_calling_method=tool_calling_method ) - if self.model_name == "deepseek-reasoner": + if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"): # deepseek-reasoner does not support function calling self.use_deepseek_r1 = True # deepseek-reasoner only support 64000 context @@ -191,6 +191,7 @@ class CustomAgent(Agent): parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", "")) parsed: AgentOutput = self.AgentOutput(**parsed_json) if parsed is None: + logger.debug(ai_message.content) raise ValueError(f'Could not parse response.') else: ai_message = self.llm.invoke(input_messages) @@ -201,6 +202,7 @@ class CustomAgent(Agent): parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", "")) parsed: AgentOutput = self.AgentOutput(**parsed_json) if parsed is None: + 
logger.debug(ai_message.content) raise ValueError(f'Could not parse response.') # cut the number of actions to max_actions_per_step @@ -229,6 +231,9 @@ class CustomAgent(Agent): self.update_step_info(model_output, step_info) logger.info(f"🧠 All Memory: \n{step_info.memory}") self._save_conversation(input_messages, model_output) + # should we remove last state message? at least, deepseek-reasoner cannot remove + if self.model_name != "deepseek-reasoner": + self.message_manager._remove_last_state_message() except Exception as e: # model call failed, remove last state message from history self.message_manager._remove_last_state_message() @@ -253,7 +258,7 @@ class CustomAgent(Agent): self.consecutive_failures = 0 except Exception as e: - result = self._handle_step_error(e) + result = await self._handle_step_error(e) self._last_result = result finally: diff --git a/src/agent/custom_prompts.py b/src/agent/custom_prompts.py index c69461f..f42859e 100644 --- a/src/agent/custom_prompts.py +++ b/src/agent/custom_prompts.py @@ -26,12 +26,7 @@ class CustomSystemPrompt(SystemPrompt): "summary": "Please generate a brief natural language description for the operation in next actions based on your Thought." }, "action": [ - { - "action_name": { - // action-specific parameters - } - }, - // ... more actions in sequence + * actions in sequences, please refer to **Common action sequences**. Each output action MUST be formated as: \{action_name\: action_params\}* ] } @@ -44,7 +39,6 @@ class CustomSystemPrompt(SystemPrompt): {"click_element": {"index": 3}} ] - Navigation and extraction: [ - {"open_new_tab": {}}, {"go_to_url": {"url": "https://example.com"}}, {"extract_page_content": {}} ] @@ -127,7 +121,7 @@ class CustomSystemPrompt(SystemPrompt): AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to: 1. Analyze the provided webpage elements and structure 2. 
Plan a sequence of actions to accomplish the given task - 3. Respond with valid JSON containing your action sequence and state assessment + 3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment, No need extra content to expalin. Current date and time: {time_str} @@ -200,15 +194,16 @@ class CustomAgentMessagePrompt(AgentMessagePrompt): """ if self.result: + for i, result in enumerate(self.result): if result.include_in_memory: if result.extracted_content: - state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}" + state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}" if result.error: # only use last 300 characters of error error = result.error[-self.max_error_length:] state_description += ( - f"\nError of action {i + 1}/{len(self.result)}: ...{error}" + f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}" ) if self.state.screenshot: diff --git a/src/utils/llm.py b/src/utils/llm.py index c38df72..c17c0e9 100644 --- a/src/utils/llm.py +++ b/src/utils/llm.py @@ -25,6 +25,7 @@ from langchain_core.outputs import ( LLMResult, RunInfo, ) +from langchain_ollama import ChatOllama from langchain_core.output_parsers.base import OutputParserLike from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool @@ -98,4 +99,38 @@ class DeepSeekR1ChatOpenAI(ChatOpenAI): reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content + return AIMessage(content=content, reasoning_content=reasoning_content) + +class DeepSeekR1ChatOllama(ChatOllama): + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + org_ai_message = await super().ainvoke(input=input) + org_content = 
org_ai_message.content + reasoning_content = org_content.split("</think>")[0].replace("<think>", "") + content = org_content.split("</think>")[1] + if "**JSON Response:**" in content: + content = content.split("**JSON Response:**")[-1] + return AIMessage(content=content, reasoning_content=reasoning_content) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + org_ai_message = super().invoke(input=input) + org_content = org_ai_message.content + reasoning_content = org_content.split("</think>")[0].replace("<think>", "") + content = org_content.split("</think>")[1] + if "**JSON Response:**" in content: + content = content.split("**JSON Response:**")[-1] return AIMessage(content=content, reasoning_content=reasoning_content) \ No newline at end of file diff --git a/src/utils/utils.py b/src/utils/utils.py index 18ce403..0cc537b 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -10,7 +10,7 @@ from langchain_ollama import ChatOllama from langchain_openai import AzureChatOpenAI, ChatOpenAI import gradio as gr -from .llm import DeepSeekR1ChatOpenAI +from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama def get_llm_model(provider: str, **kwargs): """ @@ -89,12 +89,25 @@ def get_llm_model(provider: str, **kwargs): google_api_key=api_key, ) elif provider == "ollama": - return ChatOllama( - model=kwargs.get("model_name", "qwen2.5:7b"), - temperature=kwargs.get("temperature", 0.0), - num_ctx=kwargs.get("num_ctx", 32000), - base_url=kwargs.get("base_url", "http://localhost:11434"), - ) + if not kwargs.get("base_url", ""): + base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") + else: + base_url = kwargs.get("base_url") + + if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"): + return DeepSeekR1ChatOllama( + model=kwargs.get("model_name", "deepseek-r1:7b"), + temperature=kwargs.get("temperature", 0.0), + num_ctx=kwargs.get("num_ctx", 32000), +
base_url=kwargs.get("base_url", base_url), + ) + else: + return ChatOllama( + model=kwargs.get("model_name", "qwen2.5:7b"), + temperature=kwargs.get("temperature", 0.0), + num_ctx=kwargs.get("num_ctx", 32000), + base_url=kwargs.get("base_url", base_url), + ) elif provider == "azure_openai": if not kwargs.get("base_url", ""): base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "") @@ -120,7 +133,7 @@ model_names = { "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"], "deepseek": ["deepseek-chat", "deepseek-reasoner"], "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ], - "ollama": ["qwen2.5:7b", "llama2:7b"], + "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"], "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"] } diff --git a/tests/test_browser_use.py b/tests/test_browser_use.py index 925b81d..1921995 100644 --- a/tests/test_browser_use.py +++ b/tests/test_browser_use.py @@ -257,22 +257,26 @@ async def test_browser_use_custom_v2(): # temperature=0.8 # ) - llm = utils.get_llm_model( - provider="deepseek", - model_name="deepseek-chat", - temperature=0.8 - ) + # llm = utils.get_llm_model( + # provider="deepseek", + # model_name="deepseek-chat", + # temperature=0.8 + # ) # llm = utils.get_llm_model( # provider="ollama", model_name="qwen2.5:7b", temperature=0.5 # ) + + # llm = utils.get_llm_model( + # provider="ollama", model_name="deepseek-r1:14b", temperature=0.5 + # ) controller = CustomController() use_own_browser = False disable_security = True use_vision = False # Set to False when using DeepSeek - max_actions_per_step = 1 + max_actions_per_step = 10 playwright = None browser = None browser_context = None @@ -303,7 +307,7 @@ async def test_browser_use_custom_v2(): ) ) agent = CustomAgent( - task="give me stock price of Nvidia and tesla", + task="go to google.com and type 'Nvidia' click search and give me the first url", add_infos="", # 
some hints for llm to complete the task llm=llm, browser=browser, diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 2bf4751..8809b89 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -142,12 +142,21 @@ def test_ollama_model(): llm = ChatOllama(model="qwen2.5:7b") ai_msg = llm.invoke("Sing a ballad of LangChain.") print(ai_msg.content) + +def test_deepseek_r1_ollama_model(): + from src.utils.llm import DeepSeekR1ChatOllama + + llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b") + ai_msg = llm.invoke("how many r in strawberry?") + print(ai_msg.content) + pdb.set_trace() if __name__ == '__main__': # test_openai_model() # test_gemini_model() # test_azure_openai_model() - test_deepseek_model() + # test_deepseek_model() # test_ollama_model() - # test_deepseek_r1_model() + test_deepseek_r1_model() + # test_deepseek_r1_ollama_model() \ No newline at end of file diff --git a/webui.py b/webui.py index 5a1130d..f2035f3 100644 --- a/webui.py +++ b/webui.py @@ -658,7 +658,8 @@ def create_ui(config, theme_name="Ocean"): interactive=True, allow_custom_value=True, # Allow users to input custom model names choices=["auto", "json_schema", "function_calling"], - info="Tool Calls Funtion Name" + info="Tool Calls Funtion Name", + visible=False ) with gr.TabItem("🔧 LLM Configuration", id=2):