diff --git a/.env.example b/.env.example
index 7b53b7a..fe2c67c 100644
--- a/.env.example
+++ b/.env.example
@@ -11,6 +11,8 @@ AZURE_OPENAI_API_KEY=
DEEPSEEK_ENDPOINT=https://api.deepseek.com
DEEPSEEK_API_KEY=
+OLLAMA_ENDPOINT=http://localhost:11434
+
# Set to false to disable anonymized telemetry
ANONYMIZED_TELEMETRY=true
diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py
index fc69a13..77ba6c3 100644
--- a/src/agent/custom_agent.py
+++ b/src/agent/custom_agent.py
@@ -98,7 +98,7 @@ class CustomAgent(Agent):
register_done_callback=register_done_callback,
tool_calling_method=tool_calling_method
)
- if self.model_name == "deepseek-reasoner":
+ if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"):
# deepseek-reasoner does not support function calling
self.use_deepseek_r1 = True
# deepseek-reasoner only support 64000 context
@@ -191,6 +191,7 @@ class CustomAgent(Agent):
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
+ logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
else:
ai_message = self.llm.invoke(input_messages)
@@ -201,6 +202,7 @@ class CustomAgent(Agent):
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
+ logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
# cut the number of actions to max_actions_per_step
@@ -229,6 +231,9 @@ class CustomAgent(Agent):
self.update_step_info(model_output, step_info)
logger.info(f"🧠All Memory: \n{step_info.memory}")
self._save_conversation(input_messages, model_output)
+ # should we remove last state message? at least, deepseek-reasoner cannot remove
+ if self.model_name != "deepseek-reasoner":
+ self.message_manager._remove_last_state_message()
except Exception as e:
# model call failed, remove last state message from history
self.message_manager._remove_last_state_message()
@@ -253,7 +258,7 @@ class CustomAgent(Agent):
self.consecutive_failures = 0
except Exception as e:
- result = self._handle_step_error(e)
+ result = await self._handle_step_error(e)
self._last_result = result
finally:
diff --git a/src/agent/custom_prompts.py b/src/agent/custom_prompts.py
index c69461f..f42859e 100644
--- a/src/agent/custom_prompts.py
+++ b/src/agent/custom_prompts.py
@@ -26,12 +26,7 @@ class CustomSystemPrompt(SystemPrompt):
"summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
},
"action": [
- {
- "action_name": {
- // action-specific parameters
- }
- },
- // ... more actions in sequence
+ * actions in sequences, please refer to **Common action sequences**. Each output action MUST be formated as: \{action_name\: action_params\}*
]
}
@@ -44,7 +39,6 @@ class CustomSystemPrompt(SystemPrompt):
{"click_element": {"index": 3}}
]
- Navigation and extraction: [
- {"open_new_tab": {}},
{"go_to_url": {"url": "https://example.com"}},
{"extract_page_content": {}}
]
@@ -127,7 +121,7 @@ class CustomSystemPrompt(SystemPrompt):
AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
1. Analyze the provided webpage elements and structure
2. Plan a sequence of actions to accomplish the given task
- 3. Respond with valid JSON containing your action sequence and state assessment
+ 3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment, No need extra content to expalin.
Current date and time: {time_str}
@@ -200,15 +194,16 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
"""
if self.result:
+
for i, result in enumerate(self.result):
if result.include_in_memory:
if result.extracted_content:
- state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}"
+ state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}"
if result.error:
# only use last 300 characters of error
error = result.error[-self.max_error_length:]
state_description += (
- f"\nError of action {i + 1}/{len(self.result)}: ...{error}"
+ f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}"
)
if self.state.screenshot:
diff --git a/src/utils/llm.py b/src/utils/llm.py
index c38df72..c17c0e9 100644
--- a/src/utils/llm.py
+++ b/src/utils/llm.py
@@ -25,6 +25,7 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
+from langchain_ollama import ChatOllama
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
@@ -98,4 +99,38 @@ class DeepSeekR1ChatOpenAI(ChatOpenAI):
reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+class DeepSeekR1ChatOllama(ChatOllama):
+
+ async def ainvoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ org_ai_message = await super().ainvoke(input=input)
+ org_content = org_ai_message.content
+ reasoning_content = org_content.split("")[0].replace("", "")
+ content = org_content.split("")[1]
+ if "**JSON Response:**" in content:
+ content = content.split("**JSON Response:**")[-1]
+ return AIMessage(content=content, reasoning_content=reasoning_content)
+
+ def invoke(
+ self,
+ input: LanguageModelInput,
+ config: Optional[RunnableConfig] = None,
+ *,
+ stop: Optional[list[str]] = None,
+ **kwargs: Any,
+ ) -> AIMessage:
+ org_ai_message = super().invoke(input=input)
+ org_content = org_ai_message.content
+ reasoning_content = org_content.split("")[0].replace("", "")
+ content = org_content.split("")[1]
+ if "**JSON Response:**" in content:
+ content = content.split("**JSON Response:**")[-1]
return AIMessage(content=content, reasoning_content=reasoning_content)
\ No newline at end of file
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 18ce403..0cc537b 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -10,7 +10,7 @@ from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
import gradio as gr
-from .llm import DeepSeekR1ChatOpenAI
+from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
def get_llm_model(provider: str, **kwargs):
"""
@@ -89,12 +89,25 @@ def get_llm_model(provider: str, **kwargs):
google_api_key=api_key,
)
elif provider == "ollama":
- return ChatOllama(
- model=kwargs.get("model_name", "qwen2.5:7b"),
- temperature=kwargs.get("temperature", 0.0),
- num_ctx=kwargs.get("num_ctx", 32000),
- base_url=kwargs.get("base_url", "http://localhost:11434"),
- )
+ if not kwargs.get("base_url", ""):
+ base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
+ else:
+ base_url = kwargs.get("base_url")
+
+ if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
+ return DeepSeekR1ChatOllama(
+ model=kwargs.get("model_name", "deepseek-r1:7b"),
+ temperature=kwargs.get("temperature", 0.0),
+ num_ctx=kwargs.get("num_ctx", 32000),
+ base_url=kwargs.get("base_url", base_url),
+ )
+ else:
+ return ChatOllama(
+ model=kwargs.get("model_name", "qwen2.5:7b"),
+ temperature=kwargs.get("temperature", 0.0),
+ num_ctx=kwargs.get("num_ctx", 32000),
+ base_url=kwargs.get("base_url", base_url),
+ )
elif provider == "azure_openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
@@ -120,7 +133,7 @@ model_names = {
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
"gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
- "ollama": ["qwen2.5:7b", "llama2:7b"],
+ "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
}
diff --git a/tests/test_browser_use.py b/tests/test_browser_use.py
index 925b81d..1921995 100644
--- a/tests/test_browser_use.py
+++ b/tests/test_browser_use.py
@@ -257,22 +257,26 @@ async def test_browser_use_custom_v2():
# temperature=0.8
# )
- llm = utils.get_llm_model(
- provider="deepseek",
- model_name="deepseek-chat",
- temperature=0.8
- )
+ # llm = utils.get_llm_model(
+ # provider="deepseek",
+ # model_name="deepseek-chat",
+ # temperature=0.8
+ # )
# llm = utils.get_llm_model(
# provider="ollama", model_name="qwen2.5:7b", temperature=0.5
# )
+
+ # llm = utils.get_llm_model(
+ # provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
+ # )
controller = CustomController()
use_own_browser = False
disable_security = True
use_vision = False # Set to False when using DeepSeek
- max_actions_per_step = 1
+ max_actions_per_step = 10
playwright = None
browser = None
browser_context = None
@@ -303,7 +307,7 @@ async def test_browser_use_custom_v2():
)
)
agent = CustomAgent(
- task="give me stock price of Nvidia and tesla",
+ task="go to google.com and type 'Nvidia' click search and give me the first url",
add_infos="", # some hints for llm to complete the task
llm=llm,
browser=browser,
diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py
index 2bf4751..8809b89 100644
--- a/tests/test_llm_api.py
+++ b/tests/test_llm_api.py
@@ -142,12 +142,21 @@ def test_ollama_model():
llm = ChatOllama(model="qwen2.5:7b")
ai_msg = llm.invoke("Sing a ballad of LangChain.")
print(ai_msg.content)
+
+def test_deepseek_r1_ollama_model():
+ from src.utils.llm import DeepSeekR1ChatOllama
+
+ llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
+ ai_msg = llm.invoke("how many r in strawberry?")
+ print(ai_msg.content)
+ pdb.set_trace()
if __name__ == '__main__':
# test_openai_model()
# test_gemini_model()
# test_azure_openai_model()
- test_deepseek_model()
+ # test_deepseek_model()
# test_ollama_model()
- # test_deepseek_r1_model()
+ test_deepseek_r1_model()
+ # test_deepseek_r1_ollama_model()
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 5a1130d..f2035f3 100644
--- a/webui.py
+++ b/webui.py
@@ -658,7 +658,8 @@ def create_ui(config, theme_name="Ocean"):
interactive=True,
allow_custom_value=True, # Allow users to input custom model names
choices=["auto", "json_schema", "function_calling"],
- info="Tool Calls Funtion Name"
+ info="Tool Calls Funtion Name",
+ visible=False
)
with gr.TabItem("🔧 LLM Configuration", id=2):