Fix issue #4184: '[LLM] Support LLM routing through notdiamond'

This commit is contained in:
openhands 2024-10-03 11:44:20 +00:00
parent 5c31fd9357
commit 4fb3b0dcd5
4 changed files with 164 additions and 1 deletions

View File

@ -47,3 +47,5 @@ class ConfigType(str, Enum):
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
LLM_PROVIDERS = 'LLM_PROVIDERS'
LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'

View File

@ -2,7 +2,7 @@ import copy
import time
import warnings
from functools import partial
from typing import Any
from typing import Any, List, Tuple
from openhands.core.config import LLMConfig
@ -26,6 +26,7 @@ from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.retry_mixin import RetryMixin
from openhands.llm.llm_router import LLMRouter
__all__ = ['LLM']
@ -77,6 +78,11 @@ class LLM(RetryMixin, DebugMixin):
# list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
# - 'messages': list of messages
# - 'response': response from the LLM
if self.config.llm_router_enabled:
self.router = LLMRouter(config, metrics)
else:
self.router = None
self.llm_completions: list[dict[str, Any]] = []
# litellm actually uses base Exception here for unknown model
@ -123,6 +129,7 @@ class LLM(RetryMixin, DebugMixin):
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
@ -173,6 +180,7 @@ class LLM(RetryMixin, DebugMixin):
if not messages:
raise ValueError(
'The messages list is empty. At least one message is required.'
)
# log the entire LLM prompt
@ -211,6 +219,40 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper
def complete(
self,
messages: List[Message],
**kwargs: Any,
) -> Tuple[str, float]:
"""Complete the given messages using the best selected model or the default model."""
start_time = time.time()
if self.router:
response, _ = self.router.complete(messages, **kwargs)
else:
response = self._completion(
messages=[{"role": msg.role, "content": msg.content} for msg in messages],
**kwargs
)
latency = time.time() - start_time
return response.choices[0].message.content, latency
def stream(
self,
messages: List[Message],
**kwargs: Any,
):
"""Stream the response using the best selected model or the default model."""
if self.router:
yield from self.router.stream(messages, **kwargs)
else:
yield from self._completion(
messages=[{"role": msg.role, "content": msg.content} for msg in messages],
stream=True,
**kwargs
)
@property
def completion(self):
"""Decorator for the litellm completion function.

View File

@ -0,0 +1,65 @@
import os
from typing import List, Tuple, Any
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM
from openhands.core.message import Message
from openhands.core.metrics import Metrics
class LLMRouter(LLM):
"""LLMRouter class that selects the best LLM for a given query."""
def __init__(
self,
config: LLMConfig,
metrics: Metrics | None = None,
):
super().__init__(config, metrics)
self.llm_providers: List[str] = config.llm_providers
self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
if not self.notdiamond_api_key:
raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")
from notdiamond import NotDiamond
self.client = NotDiamond()
def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
"""Select the best model for the given messages."""
formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
session_id, provider = self.client.chat.completions.model_select(
messages=formatted_messages,
model=self.llm_providers
)
return provider.model, session_id
def complete(
self,
messages: List[Message],
**kwargs: Any,
) -> Tuple[str, float]:
"""Complete the given messages using the best selected model."""
selected_model, session_id = self._select_model(messages)
# Create a new LLM instance with the selected model
selected_config = LLMConfig(model=selected_model)
selected_llm = LLM(config=selected_config, metrics=self.metrics)
# Use the selected LLM to complete the messages
response, latency = selected_llm.complete(messages, **kwargs)
return response, latency
def stream(
self,
messages: List[Message],
**kwargs: Any,
):
"""Stream the response using the best selected model."""
selected_model, session_id = self._select_model(messages)
# Create a new LLM instance with the selected model
selected_config = LLMConfig(model=selected_model)
selected_llm = LLM(config=selected_config, metrics=self.metrics)
# Use the selected LLM to stream the response
yield from selected_llm.stream(messages, **kwargs)

View File

@ -0,0 +1,54 @@
import pytest
from unittest.mock import Mock, patch
from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_router import LLMRouter
@pytest.fixture
def mock_notdiamond():
with patch('openhands.llm.llm_router.NotDiamond') as mock:
yield mock
def test_llm_router_enabled(mock_notdiamond):
config = LLMConfig(
model="test-model",
llm_router_enabled=True,
llm_providers=["model1", "model2"]
)
llm = LLM(config)
assert isinstance(llm.router, LLMRouter)
messages = [Message(role="user", content="Hello")]
mock_response = Mock()
mock_response.choices[0].message.content = "Hello, how can I help you?"
llm.router.complete = Mock(return_value=(mock_response, 0.5))
response, latency = llm.complete(messages)
assert response == "Hello, how can I help you?"
assert isinstance(latency, float)
llm.router.complete.assert_called_once_with(messages)
def test_llm_router_disabled():
config = LLMConfig(
model="test-model",
llm_router_enabled=False
)
llm = LLM(config)
assert llm.router is None
messages = [Message(role="user", content="Hello")]
with patch.object(llm, '_completion') as mock_completion:
mock_response = Mock()
mock_response.choices[0].message.content = "Hello, how can I help you?"
mock_completion.return_value = mock_response
response, latency = llm.complete(messages)
assert response == "Hello, how can I help you?"
assert isinstance(latency, float)
mock_completion.assert_called_once()