Fix issue #4184: '[LLM] Support LLM routing through notdiamond'
This commit is contained in:
parent
5c31fd9357
commit
4fb3b0dcd5
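
For orientation, here is a minimal sketch of how the routing added in this commit would be switched on. It assumes the LLMConfig fields exercised in the new unit tests (llm_router_enabled, llm_providers) and the NOTDIAMOND_API_KEY check in llm_router.py; the model name and provider strings are illustrative placeholders, not values from the commit.

import os

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

# LLMRouter refuses to start without this key (see llm_router.py below)
os.environ.setdefault('NOTDIAMOND_API_KEY', '<your-notdiamond-key>')

config = LLMConfig(
    model='gpt-4o',  # default model, used when routing is disabled
    llm_router_enabled=True,  # new flag: LLM.__init__ builds an LLMRouter
    llm_providers=['openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620'],  # candidate models for Not Diamond
)
llm = LLM(config)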
@@ -47,3 +47,5 @@ class ConfigType(str, Enum):
    WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
    WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
    WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
    LLM_PROVIDERS = 'LLM_PROVIDERS'
    LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'
@@ -2,7 +2,7 @@ import copy
import time
import warnings
from functools import partial
from typing import Any
from typing import Any, List, Tuple

from openhands.core.config import LLMConfig
@@ -26,6 +26,7 @@ from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.retry_mixin import RetryMixin
# LLMRouter is imported lazily inside LLM.__init__ instead of here:
# llm_router.py imports LLM from this module, so a top-level import is circular

__all__ = ['LLM']
@@ -77,6 +78,11 @@ class LLM(RetryMixin, DebugMixin):
        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM

        if self.config.llm_router_enabled:
            # imported lazily here to avoid a circular import with llm_router.py
            from openhands.llm.llm_router import LLMRouter

            # LLMRouter is itself an LLM subclass, so guard against infinite recursion
            self.router = LLMRouter(config, metrics) if not isinstance(self, LLMRouter) else None
        else:
            self.router = None
        self.llm_completions: list[dict[str, Any]] = []

        # litellm actually uses base Exception here for unknown model
@@ -123,6 +129,7 @@ class LLM(RetryMixin, DebugMixin):
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
@@ -173,6 +180,7 @@ class LLM(RetryMixin, DebugMixin):
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
@@ -211,6 +219,40 @@ class LLM(RetryMixin, DebugMixin):

        self._completion = wrapper

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model or the default model."""
        start_time = time.time()

        if self.router:
            # LLMRouter.complete already returns the final text content, not a raw litellm response
            content, _ = self.router.complete(messages, **kwargs)
            return content, time.time() - start_time

        response = self._completion(
            messages=[{"role": msg.role, "content": msg.content} for msg in messages],
            **kwargs,
        )
        latency = time.time() - start_time
        return response.choices[0].message.content, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model or the default model."""
        if self.router:
            yield from self.router.stream(messages, **kwargs)
        else:
            yield from self._completion(
                messages=[{"role": msg.role, "content": msg.content} for msg in messages],
                stream=True,
                **kwargs,
            )

    @property
    def completion(self):
        """Decorator for the litellm completion function.
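
Continuing the sketch above (not part of the diff), a caller would use the two new helpers roughly like this; `llm` is the instance constructed earlier and the message text is illustrative.

from openhands.core.message import Message

messages = [Message(role='user', content='Summarize the latest build failure.')]

# complete() routes through LLMRouter when llm_router_enabled is set,
# otherwise it falls back to the wrapped litellm completion
text, latency = llm.complete(messages)
print(f'{latency:.2f}s -> {text}')

# stream() has the same routing behaviour but yields chunks instead
for chunk in llm.stream(messages):
    print(chunk)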
openhands/llm/llm_router.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import os
from typing import List, Tuple, Any

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM
from openhands.core.message import Message
from openhands.core.metrics import Metrics


class LLMRouter(LLM):
    """LLMRouter class that selects the best LLM for a given query."""

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        super().__init__(config, metrics)
        self.llm_providers: List[str] = config.llm_providers
        self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
        if not self.notdiamond_api_key:
            raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")

        from notdiamond import NotDiamond
        self.client = NotDiamond()

    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
        """Select the best model for the given messages."""
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        session_id, provider = self.client.chat.completions.model_select(
            messages=formatted_messages,
            model=self.llm_providers
        )
        return provider.model, session_id

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model."""
        selected_model, session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to complete the messages
        response, latency = selected_llm.complete(messages, **kwargs)

        return response, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model."""
        selected_model, session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to stream the response
        yield from selected_llm.stream(messages, **kwargs)
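
As a further illustration (not part of the commit), LLMRouter can be exercised directly; _select_model is a private helper called here only for demonstration, and the 'provider/model' strings follow Not Diamond's naming convention, so treat them as placeholders.

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm_router import LLMRouter

router = LLMRouter(
    LLMConfig(
        model='gpt-4o',
        llm_router_enabled=True,
        llm_providers=['openai/gpt-4o-mini', 'anthropic/claude-3-5-sonnet-20240620'],
    )
)
best_model, session_id = router._select_model(
    [Message(role='user', content='Write a haiku about routing.')]
)
print(best_model, session_id)  # the chosen model name plus Not Diamond's session id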
tests/unit/test_llm_router.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import pytest
from unittest.mock import MagicMock, Mock, patch

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_router import LLMRouter


@pytest.fixture
def mock_notdiamond(monkeypatch):
    # NotDiamond is imported lazily inside LLMRouter.__init__, so patch it in its own package
    monkeypatch.setenv("NOTDIAMOND_API_KEY", "test-key")
    with patch("notdiamond.NotDiamond") as mock:
        yield mock


def test_llm_router_enabled(mock_notdiamond):
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=True,
        llm_providers=["model1", "model2"],
    )
    llm = LLM(config)

    assert isinstance(llm.router, LLMRouter)

    messages = [Message(role="user", content="Hello")]
    # LLMRouter.complete returns the final text content and its latency
    llm.router.complete = Mock(return_value=("Hello, how can I help you?", 0.5))

    response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    llm.router.complete.assert_called_once_with(messages)


def test_llm_router_disabled():
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=False,
    )
    llm = LLM(config)

    assert llm.router is None

    messages = [Message(role="user", content="Hello")]
    with patch.object(llm, '_completion') as mock_completion:
        # MagicMock (rather than Mock) is needed so that mock_response.choices[0] is subscriptable
        mock_response = MagicMock()
        mock_response.choices[0].message.content = "Hello, how can I help you?"
        mock_completion.return_value = mock_response

        response, latency = llm.complete(messages)

        assert response == "Hello, how can I help you?"
        assert isinstance(latency, float)
        mock_completion.assert_called_once()
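
The new tests can be run on their own; the snippet below is just the programmatic equivalent of invoking pytest on the file added above.

import sys

import pytest

# equivalent to: pytest tests/unit/test_llm_router.py -v
sys.exit(pytest.main(['tests/unit/test_llm_router.py', '-v']))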