From 61de4e8631bf5e66bcf6358ff70fd491f1599f91 Mon Sep 17 00:00:00 2001 From: M87monster <2772762669@qq.com> Date: Sat, 12 Apr 2025 22:21:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=B8=BA=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E4=BD=BF=E7=94=A8OpenAIChat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/utils/llm.py | 88 ---------------------------------------------- src/utils/utils.py | 6 ++-- 2 files changed, 2 insertions(+), 92 deletions(-) diff --git a/src/utils/llm.py b/src/utils/llm.py index afb9def..0b601ed 100644 --- a/src/utils/llm.py +++ b/src/utils/llm.py @@ -136,91 +136,3 @@ class DeepSeekR1ChatOllama(ChatOllama): if "**JSON Response:**" in content: content = content.split("**JSON Response:**")[-1] return AIMessage(content=content, reasoning_content=reasoning_content) - - -class SiliconFlowChat(ChatOpenAI): - """Wrapper for SiliconFlow Chat API, fully compatible with OpenAI-spec format.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - - # Ensure the API client is initialized with SiliconFlow's endpoint and key - self.client = OpenAI( - api_key=kwargs.get("api_key"), - base_url=kwargs.get("base_url") - ) - - async def ainvoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> AIMessage: - """Async call SiliconFlow API.""" - - # Convert input messages into OpenAI-compatible format - message_history = [] - for input_msg in input: - if isinstance(input_msg, SystemMessage): - message_history.append({"role": "system", "content": input_msg.content}) - elif isinstance(input_msg, AIMessage): - message_history.append({"role": "assistant", "content": input_msg.content}) - else: # HumanMessage or similar - message_history.append({"role": "user", "content": input_msg.content}) - - # Send request to SiliconFlow API (OpenAI-spec endpoint) - response = await self.client.chat.completions.create( - model=self.model_name, - messages=message_history, - stop=stop, - **kwargs, - ) - - # Extract the AI response (SiliconFlow's response must match OpenAI format) - if hasattr(response.choices[0].message, "reasoning_content"): - reasoning_content = response.choices[0].message.reasoning_content - else: - reasoning_content = None - - content = response.choices[0].message.content - return AIMessage(content=content, reasoning_content=reasoning_content) # Return reasoning_content if needed - - def invoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> AIMessage: - """Sync call SiliconFlow API.""" - - # Same conversion as async version - message_history = [] - for input_msg in input: - if isinstance(input_msg, SystemMessage): - message_history.append({"role": "system", "content": input_msg.content}) - elif isinstance(input_msg, AIMessage): - message_history.append({"role": "assistant", "content": input_msg.content}) - else: - message_history.append({"role": "user", "content": input_msg.content}) - - # Sync call - response = self.client.chat.completions.create( - model=self.model_name, - messages=message_history, - stop=stop, - **kwargs, - ) - - # Handle reasoning_content (if supported) - reasoning_content = None - if hasattr(response.choices[0].message, "reasoning_content"): - reasoning_content = response.choices[0].message.reasoning_content - - return AIMessage( - content=response.choices[0].message.content, - reasoning_content=reasoning_content, # Only if SiliconFlow supports it - ) diff --git a/src/utils/utils.py b/src/utils/utils.py index a6e346b..62fc8a8 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI from langchain_ollama import ChatOllama from langchain_openai import AzureChatOpenAI, ChatOpenAI -from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama,SiliconFlowChat +from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama PROVIDER_DISPLAY_NAMES = { "openai": "OpenAI", @@ -177,13 +177,11 @@ def get_llm_model(provider: str, **kwargs): base_url = os.getenv("SiliconFLOW_ENDPOINT", "") else: base_url = kwargs.get("base_url") - return SiliconFlowChat( + return ChatOpenAI( api_key=api_key, base_url=base_url, model_name=kwargs.get("model_name", "Qwen/QwQ-32B"), temperature=kwargs.get("temperature", 0.0), - max_tokens=kwargs.get("max_tokens", 512), - frequency_penalty=kwargs.get("frequency_penalty", 0.5), ) else: raise ValueError(f"Unsupported provider: {provider}")