修改为直接使用 ChatOpenAI

This commit is contained in:
M87monster
2025-04-12 22:21:47 +08:00
parent d711c85644
commit 61de4e8631
2 changed files with 2 additions and 92 deletions

View File

@@ -136,91 +136,3 @@ class DeepSeekR1ChatOllama(ChatOllama):
if "**JSON Response:**" in content:
content = content.split("**JSON Response:**")[-1]
return AIMessage(content=content, reasoning_content=reasoning_content)
class SiliconFlowChat(ChatOpenAI):
    """Wrapper for SiliconFlow's Chat API, which is fully OpenAI-spec compatible.

    Overrides ``invoke``/``ainvoke`` to call SiliconFlow directly through
    OpenAI clients pointed at SiliconFlow's endpoint, and to surface the
    optional ``reasoning_content`` field some SiliconFlow models return.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Local import keeps this class usable even if the module-level
        # imports only brought in the sync ``OpenAI`` client.
        from openai import AsyncOpenAI

        # Sync client for ``invoke``.
        self.client = OpenAI(
            api_key=kwargs.get("api_key"),
            base_url=kwargs.get("base_url"),
        )
        # BUG FIX: the original code did ``await self.client.chat...`` on the
        # *sync* client, whose ``create()`` returns a plain ChatCompletion —
        # awaiting it raises TypeError. ``ainvoke`` needs a real async client.
        self.async_client = AsyncOpenAI(
            api_key=kwargs.get("api_key"),
            base_url=kwargs.get("base_url"),
        )

    @staticmethod
    def _to_openai_messages(input: LanguageModelInput) -> List[dict]:
        """Convert LangChain messages into OpenAI-format role/content dicts."""
        message_history = []
        for input_msg in input:
            if isinstance(input_msg, SystemMessage):
                role = "system"
            elif isinstance(input_msg, AIMessage):
                role = "assistant"
            else:  # HumanMessage or similar
                role = "user"
            message_history.append({"role": role, "content": input_msg.content})
        return message_history

    @staticmethod
    def _to_ai_message(response: Any) -> AIMessage:
        """Build an AIMessage from an OpenAI-spec response.

        ``reasoning_content`` is optional — only some SiliconFlow models
        (e.g. reasoning models) include it; ``getattr`` defaults it to None.
        """
        message = response.choices[0].message
        reasoning_content = getattr(message, "reasoning_content", None)
        return AIMessage(content=message.content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async call to the SiliconFlow API (OpenAI-spec endpoint)."""
        response = await self.async_client.chat.completions.create(
            model=self.model_name,
            messages=self._to_openai_messages(input),
            stop=stop,
            **kwargs,
        )
        return self._to_ai_message(response)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Sync call to the SiliconFlow API (OpenAI-spec endpoint)."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._to_openai_messages(input),
            stop=stop,
            **kwargs,
        )
        return self._to_ai_message(response)

View File

@@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama,SiliconFlowChat
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
PROVIDER_DISPLAY_NAMES = {
"openai": "OpenAI",
@@ -177,13 +177,11 @@ def get_llm_model(provider: str, **kwargs):
base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
return SiliconFlowChat(
return ChatOpenAI(
api_key=api_key,
base_url=base_url,
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
temperature=kwargs.get("temperature", 0.0),
max_tokens=kwargs.get("max_tokens", 512),
frequency_penalty=kwargs.get("frequency_penalty", 0.5),
)
else:
raise ValueError(f"Unsupported provider: {provider}")