mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
remove llamaindex (#7151)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent
924acb182b
commit
5128377baa
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@ -10,12 +10,6 @@ updates:
|
||||
pre-commit:
|
||||
patterns:
|
||||
- "pre-commit"
|
||||
llama:
|
||||
patterns:
|
||||
- "llama*"
|
||||
chromadb:
|
||||
patterns:
|
||||
- "chromadb"
|
||||
browsergym:
|
||||
patterns:
|
||||
- "browsergym*"
|
||||
|
||||
2
.github/workflows/dummy-agent-test.yml
vendored
2
.github/workflows/dummy-agent-test.yml
vendored
@ -36,7 +36,7 @@ jobs:
|
||||
python-version: '3.12'
|
||||
cache: 'poetry'
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --without evaluation,llama-index
|
||||
run: poetry install --without evaluation
|
||||
- name: Build Environment
|
||||
run: make build
|
||||
- name: Run tests
|
||||
|
||||
2
.github/workflows/integration-runner.yml
vendored
2
.github/workflows/integration-runner.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
Hi! I started running the integration tests on your PR. You will receive a comment with the results shortly.
|
||||
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --without evaluation,llama-index
|
||||
run: poetry install --without evaluation
|
||||
|
||||
- name: Configure config.toml for testing with Haiku
|
||||
env:
|
||||
|
||||
4
.github/workflows/py-unit-tests.yml
vendored
4
.github/workflows/py-unit-tests.yml
vendored
@ -44,11 +44,11 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --without evaluation,llama-index
|
||||
run: poetry install --without evaluation
|
||||
- name: Build Environment
|
||||
run: make build
|
||||
- name: Run Tests
|
||||
run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_long_term_memory.py
|
||||
run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
|
||||
31
Makefile
31
Makefile
@ -133,7 +133,7 @@ install-python-dependencies:
|
||||
export HNSWLIB_NO_NATIVE=1; \
|
||||
poetry run pip install chroma-hnswlib; \
|
||||
fi
|
||||
@poetry install --without llama-index
|
||||
@poetry install
|
||||
@if [ -f "/etc/manjaro-release" ]; then \
|
||||
echo "$(BLUE)Detected Manjaro Linux. Installing Playwright dependencies...$(RESET)"; \
|
||||
poetry run pip install playwright; \
|
||||
@ -265,35 +265,6 @@ setup-config-prompts:
|
||||
@read -p "Enter your LLM base URL [mostly used for local LLMs, leave blank if not needed - example: http://localhost:5001/v1/]: " llm_base_url; \
|
||||
if [[ ! -z "$$llm_base_url" ]]; then echo "base_url=\"$$llm_base_url\"" >> $(CONFIG_FILE).tmp; fi
|
||||
|
||||
@echo "Enter your LLM Embedding Model"; \
|
||||
echo "Choices are:"; \
|
||||
echo " - openai"; \
|
||||
echo " - azureopenai"; \
|
||||
echo " - Embeddings available only with OllamaEmbedding:"; \
|
||||
echo " - llama2"; \
|
||||
echo " - mxbai-embed-large"; \
|
||||
echo " - nomic-embed-text"; \
|
||||
echo " - all-minilm"; \
|
||||
echo " - stable-code"; \
|
||||
echo " - bge-m3"; \
|
||||
echo " - bge-large"; \
|
||||
echo " - paraphrase-multilingual"; \
|
||||
echo " - snowflake-arctic-embed"; \
|
||||
echo " - Leave blank to default to 'BAAI/bge-small-en-v1.5' via huggingface"; \
|
||||
read -p "> " llm_embedding_model; \
|
||||
echo "embedding_model=\"$$llm_embedding_model\"" >> $(CONFIG_FILE).tmp; \
|
||||
if [ "$$llm_embedding_model" = "llama2" ] || [ "$$llm_embedding_model" = "mxbai-embed-large" ] || [ "$$llm_embedding_model" = "nomic-embed-text" ] || [ "$$llm_embedding_model" = "all-minilm" ] || [ "$$llm_embedding_model" = "stable-code" ]; then \
|
||||
read -p "Enter the local model URL for the embedding model (will set llm.embedding_base_url): " llm_embedding_base_url; \
|
||||
echo "embedding_base_url=\"$$llm_embedding_base_url\"" >> $(CONFIG_FILE).tmp; \
|
||||
elif [ "$$llm_embedding_model" = "azureopenai" ]; then \
|
||||
read -p "Enter the Azure endpoint URL (will overwrite llm.base_url): " llm_base_url; \
|
||||
echo "base_url=\"$$llm_base_url\"" >> $(CONFIG_FILE).tmp; \
|
||||
read -p "Enter the Azure LLM Embedding Deployment Name: " llm_embedding_deployment_name; \
|
||||
echo "embedding_deployment_name=\"$$llm_embedding_deployment_name\"" >> $(CONFIG_FILE).tmp; \
|
||||
read -p "Enter the Azure API Version: " llm_api_version; \
|
||||
echo "api_version=\"$$llm_api_version\"" >> $(CONFIG_FILE).tmp; \
|
||||
fi
|
||||
|
||||
|
||||
# Develop in container
|
||||
docker-dev:
|
||||
|
||||
@ -132,15 +132,6 @@ api_key = ""
|
||||
# Custom LLM provider
|
||||
#custom_llm_provider = ""
|
||||
|
||||
# Embedding API base URL
|
||||
#embedding_base_url = ""
|
||||
|
||||
# Embedding deployment name
|
||||
#embedding_deployment_name = ""
|
||||
|
||||
# Embedding model to use
|
||||
embedding_model = "local"
|
||||
|
||||
# Maximum number of characters in an observation's content
|
||||
#max_message_chars = 10000
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ RUN apt-get update -y \
|
||||
|
||||
COPY ./pyproject.toml ./poetry.lock ./
|
||||
RUN touch README.md
|
||||
RUN export POETRY_CACHE_DIR && poetry install --without evaluation,llama-index --no-root && rm -rf $POETRY_CACHE_DIR
|
||||
RUN export POETRY_CACHE_DIR && poetry install --without evaluation --no-root && rm -rf $POETRY_CACHE_DIR
|
||||
|
||||
FROM python:3.12.3-slim AS openhands-app
|
||||
|
||||
|
||||
@ -197,21 +197,6 @@ For development setups, you can also define custom named LLM configurations. See
|
||||
- Default: `""`
|
||||
- Description: Custom LLM provider
|
||||
|
||||
### Embeddings
|
||||
- `embedding_base_url`
|
||||
- Type: `str`
|
||||
- Default: `""`
|
||||
- Description: Embedding API base URL
|
||||
|
||||
- `embedding_deployment_name`
|
||||
- Type: `str`
|
||||
- Default: `""`
|
||||
- Description: Embedding deployment name
|
||||
|
||||
- `embedding_model`
|
||||
- Type: `str`
|
||||
- Default: `"local"`
|
||||
- Description: Embedding model to use
|
||||
|
||||
### Message Handling
|
||||
- `max_message_chars`
|
||||
@ -302,16 +287,6 @@ The agent configuration options are defined in the `[agent]` and `[agent.<agent_
|
||||
- Default: `""`
|
||||
- Description: Name of the micro agent to use for this agent
|
||||
|
||||
### Memory Configuration
|
||||
- `memory_enabled`
|
||||
- Type: `bool`
|
||||
- Default: `false`
|
||||
- Description: Whether long-term memory (embeddings) is enabled
|
||||
|
||||
- `memory_max_threads`
|
||||
- Type: `int`
|
||||
- Default: `3`
|
||||
- Description: The maximum number of threads indexing at the same time for embeddings
|
||||
|
||||
### LLM Configuration
|
||||
- `llm_config`
|
||||
|
||||
@ -31,17 +31,11 @@ You will need your ChatGPT deployment name which can be found on the deployments
|
||||
- `Base URL` to your Azure API Base URL (e.g. `https://example-endpoint.openai.azure.com`)
|
||||
- `API Key` to your Azure API key
|
||||
|
||||
## Embeddings
|
||||
|
||||
OpenHands uses llama-index for embeddings. You can find their documentation on Azure [here](https://docs.llamaindex.ai/en/stable/api_reference/embeddings/azure_openai/).
|
||||
|
||||
### Azure OpenAI Configuration
|
||||
|
||||
When running OpenHands, set the following environment variables using `-e` in the
|
||||
When running OpenHands, set the following environment variable using `-e` in the
|
||||
[docker run command](/modules/usage/installation#start-the-app):
|
||||
|
||||
```
|
||||
LLM_EMBEDDING_MODEL="azureopenai"
|
||||
LLM_EMBEDDING_DEPLOYMENT_NAME="<your-embedding-deployment-name>" # e.g. "TextEmbedding...<etc>"
|
||||
LLM_API_VERSION="<api-version>" # e.g. "2024-02-15-preview"
|
||||
```
|
||||
|
||||
@ -14,8 +14,6 @@ class AgentConfig(BaseModel):
|
||||
codeact_enable_browsing: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
|
||||
codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling.
|
||||
codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False.
|
||||
memory_enabled: Whether long-term memory (embeddings) is enabled.
|
||||
memory_max_threads: The maximum number of threads indexing at the same time for embeddings. (deprecated)
|
||||
llm_config: The name of the llm config to use. If specified, this will override global llm config.
|
||||
enable_prompt_extensions: Whether to use prompt extensions (e.g., microagents, inject runtime info). Default is True.
|
||||
disabled_microagents: A list of microagents to disable (by name, without .py extension, e.g. ["github", "lint"]). Default is None.
|
||||
@ -25,8 +23,6 @@ class AgentConfig(BaseModel):
|
||||
"""
|
||||
|
||||
llm_config: str | None = Field(default=None)
|
||||
memory_enabled: bool = Field(default=False)
|
||||
memory_max_threads: int = Field(default=3)
|
||||
codeact_enable_browsing: bool = Field(default=True)
|
||||
codeact_enable_llm_editor: bool = Field(default=False)
|
||||
codeact_enable_jupyter: bool = Field(default=True)
|
||||
@ -50,11 +46,10 @@ class AgentConfig(BaseModel):
|
||||
Example:
|
||||
Apply generic agent config with custom agent overrides, e.g.
|
||||
[agent]
|
||||
memory_enabled = false
|
||||
enable_prompt_extensions = true
|
||||
enable_prompt_extensions = false
|
||||
[agent.BrowsingAgent]
|
||||
memory_enabled = true
|
||||
results in memory_enabled being true for BrowsingAgent but false for others.
|
||||
enable_prompt_extensions = true
|
||||
results in prompt_extensions being true for BrowsingAgent but false for others.
|
||||
|
||||
Returns:
|
||||
dict[str, AgentConfig]: A mapping where the key "agent" corresponds to the default configuration
|
||||
|
||||
@ -15,11 +15,8 @@ class LLMConfig(BaseModel):
|
||||
Attributes:
|
||||
model: The model to use.
|
||||
api_key: The API key to use.
|
||||
base_url: The base URL for the API. This is necessary for local LLMs. It is also used for Azure embeddings.
|
||||
base_url: The base URL for the API. This is necessary for local LLMs.
|
||||
api_version: The version of the API.
|
||||
embedding_model: The embedding model to use.
|
||||
embedding_base_url: The base URL for the embedding API.
|
||||
embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
|
||||
aws_access_key_id: The AWS access key ID.
|
||||
aws_secret_access_key: The AWS secret access key.
|
||||
aws_region_name: The AWS region name.
|
||||
@ -52,9 +49,6 @@ class LLMConfig(BaseModel):
|
||||
api_key: SecretStr | None = Field(default=None)
|
||||
base_url: str | None = Field(default=None)
|
||||
api_version: str | None = Field(default=None)
|
||||
embedding_model: str = Field(default='local')
|
||||
embedding_base_url: str | None = Field(default=None)
|
||||
embedding_deployment_name: str | None = Field(default=None)
|
||||
aws_access_key_id: SecretStr | None = Field(default=None)
|
||||
aws_secret_access_key: SecretStr | None = Field(default=None)
|
||||
aws_region_name: str | None = Field(default=None)
|
||||
|
||||
@ -282,8 +282,6 @@ def finalize_config(cfg: AppConfig):
|
||||
# make sure log_completions_folder is an absolute path
|
||||
for llm in cfg.llms.values():
|
||||
llm.log_completions_folder = os.path.abspath(llm.log_completions_folder)
|
||||
if llm.embedding_base_url is None:
|
||||
llm.embedding_base_url = llm.base_url
|
||||
|
||||
if cfg.sandbox.use_host_network and platform.system() == 'Darwin':
|
||||
logger.openhands_logger.warning(
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
- Short Term History
|
||||
- Memory Condenser
|
||||
- Long Term Memory
|
||||
|
||||
## Short Term History
|
||||
- Short term history filters the event stream and computes the messages that are injected into the context
|
||||
@ -17,7 +16,3 @@
|
||||
- Then it does the same for later chunks of events between user messages
|
||||
- If there are no more agent events, it summarizes the user messages, this time one by one, if they're large enough and not immediately after an AgentFinishAction event (we assume those are tasks, potentially important)
|
||||
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
|
||||
|
||||
## Long Term Memory
|
||||
- Long term memory component stores embeddings for events and prompts in a vector store
|
||||
- The agent can query it when it needs detailed information about a past event or to learn new actions
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from openhands.memory.condenser import Condenser
|
||||
from openhands.memory.long_term_memory import LongTermMemory
|
||||
|
||||
__all__ = ['LongTermMemory', 'Condenser']
|
||||
__all__ = ['Condenser']
|
||||
|
||||
@ -1,188 +0,0 @@
|
||||
import json
|
||||
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_to_memory
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.utils.embeddings import (
|
||||
LLAMA_INDEX_AVAILABLE,
|
||||
EmbeddingsLoader,
|
||||
check_llama_index,
|
||||
)
|
||||
|
||||
# Conditional imports based on llama_index availability
|
||||
if LLAMA_INDEX_AVAILABLE:
|
||||
import chromadb
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.indices.vector_store.retrievers.retriever import (
|
||||
VectorIndexRetriever,
|
||||
)
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class LongTermMemory:
|
||||
"""Handles storing information for the agent to access later, using chromadb."""
|
||||
|
||||
event_stream: EventStream
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: LLMConfig,
|
||||
agent_config: AgentConfig,
|
||||
event_stream: EventStream,
|
||||
):
|
||||
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
|
||||
|
||||
check_llama_index()
|
||||
|
||||
# initialize the chromadb client
|
||||
db = chromadb.PersistentClient(
|
||||
path=f'./cache/sessions/{event_stream.sid}/memory',
|
||||
# FIXME anonymized_telemetry=False,
|
||||
)
|
||||
self.collection = db.get_or_create_collection(name='memories')
|
||||
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
||||
|
||||
# embedding model
|
||||
embedding_strategy = llm_config.embedding_model
|
||||
self.embed_model = EmbeddingsLoader.get_embedding_model(
|
||||
embedding_strategy, llm_config
|
||||
)
|
||||
logger.debug(f'Using embedding model: {self.embed_model}')
|
||||
|
||||
# instantiate the index
|
||||
self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
|
||||
self.thought_idx = 0
|
||||
|
||||
# initialize the event stream
|
||||
self.event_stream = event_stream
|
||||
|
||||
# max of threads to run the pipeline
|
||||
self.memory_max_threads = agent_config.memory_max_threads
|
||||
|
||||
def add_event(self, event: Event):
|
||||
"""Adds a new event to the long term memory with a unique id.
|
||||
|
||||
Parameters:
|
||||
- event: The new event to be added to memory
|
||||
"""
|
||||
try:
|
||||
# convert the event to a memory-friendly format, and don't truncate
|
||||
event_data = event_to_memory(event, -1)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f'Failed to process event: {e}')
|
||||
return
|
||||
|
||||
# determine the event type and ID
|
||||
event_type = ''
|
||||
event_id = ''
|
||||
if 'action' in event_data:
|
||||
event_type = 'action'
|
||||
event_id = event_data['action']
|
||||
elif 'observation' in event_data:
|
||||
event_type = 'observation'
|
||||
event_id = event_data['observation']
|
||||
|
||||
# create a Document instance for the event
|
||||
doc = Document(
|
||||
text=json.dumps(event_data),
|
||||
doc_id=str(self.thought_idx),
|
||||
extra_info={
|
||||
'type': event_type,
|
||||
'id': event_id,
|
||||
'idx': self.thought_idx,
|
||||
},
|
||||
)
|
||||
self.thought_idx += 1
|
||||
logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
|
||||
self._add_document(document=doc)
|
||||
|
||||
def _add_document(self, document: 'Document'):
|
||||
"""Inserts a single document into the index."""
|
||||
self.index.insert_nodes([self._create_node(document)])
|
||||
|
||||
def _create_node(self, document: 'Document') -> 'TextNode':
|
||||
"""Create a TextNode from a Document instance."""
|
||||
return TextNode(
|
||||
text=document.text,
|
||||
doc_id=document.doc_id,
|
||||
extra_info=document.extra_info,
|
||||
)
|
||||
|
||||
def search(self, query: str, k: int = 10) -> list[str]:
|
||||
"""Searches through the current memory using VectorIndexRetriever.
|
||||
|
||||
Parameters:
|
||||
- query (str): A query to match search results to
|
||||
- k (int): Number of top results to return
|
||||
|
||||
Returns:
|
||||
- list[str]: List of top k results found in current memory
|
||||
"""
|
||||
retriever = VectorIndexRetriever(
|
||||
index=self.index,
|
||||
similarity_top_k=k,
|
||||
)
|
||||
results = retriever.retrieve(query)
|
||||
|
||||
for result in results:
|
||||
logger.debug(
|
||||
f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
|
||||
)
|
||||
|
||||
return [r.get_text() for r in results]
|
||||
|
||||
def _events_to_docs(self) -> list['Document']:
|
||||
"""Convert all events from the EventStream to documents for batch insert into the index."""
|
||||
try:
|
||||
events = self.event_stream.get_events()
|
||||
except Exception as e:
|
||||
logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
|
||||
return []
|
||||
|
||||
documents: list[Document] = []
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
# convert the event to a memory-friendly format, and don't truncate
|
||||
event_data = event_to_memory(event, -1)
|
||||
|
||||
# determine the event type and ID
|
||||
event_type = ''
|
||||
event_id = ''
|
||||
if 'action' in event_data:
|
||||
event_type = 'action'
|
||||
event_id = event_data['action']
|
||||
elif 'observation' in event_data:
|
||||
event_type = 'observation'
|
||||
event_id = event_data['observation']
|
||||
|
||||
# create a Document instance for the event
|
||||
doc = Document(
|
||||
text=json.dumps(event_data),
|
||||
doc_id=str(self.thought_idx),
|
||||
extra_info={
|
||||
'type': event_type,
|
||||
'id': event_id,
|
||||
'idx': self.thought_idx,
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
self.thought_idx += 1
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f'Failed to process event: {e}')
|
||||
continue
|
||||
|
||||
if documents:
|
||||
logger.debug(f'Batch inserting {len(documents)} documents into the index.')
|
||||
else:
|
||||
logger.debug('No valid documents found to insert into the index.')
|
||||
|
||||
return documents
|
||||
|
||||
def create_nodes(self, documents: list['Document']) -> list['TextNode']:
|
||||
"""Create nodes from a list of documents."""
|
||||
return [self._create_node(doc) for doc in documents]
|
||||
@ -1,187 +0,0 @@
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
try:
|
||||
# check if those we need later are available using importlib
|
||||
if importlib.util.find_spec('chromadb') is None:
|
||||
raise ImportError(
|
||||
'chromadb is not available. Please install it using poetry install --with llama-index'
|
||||
)
|
||||
|
||||
if (
|
||||
importlib.util.find_spec(
|
||||
'llama_index.core.indices.vector_store.retrievers.retriever'
|
||||
)
|
||||
is None
|
||||
or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
|
||||
is None
|
||||
):
|
||||
raise ImportError(
|
||||
'llama_index is not available. Please install it using poetry install --with llama-index'
|
||||
)
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
LLAMA_INDEX_AVAILABLE = True
|
||||
|
||||
except ImportError:
|
||||
LLAMA_INDEX_AVAILABLE = False
|
||||
|
||||
# Define supported embedding models
|
||||
SUPPORTED_OLLAMA_EMBED_MODELS = [
|
||||
'llama2',
|
||||
'mxbai-embed-large',
|
||||
'nomic-embed-text',
|
||||
'all-minilm',
|
||||
'stable-code',
|
||||
'bge-m3',
|
||||
'bge-large',
|
||||
'paraphrase-multilingual',
|
||||
'snowflake-arctic-embed',
|
||||
]
|
||||
|
||||
|
||||
def check_llama_index():
|
||||
"""Utility function to check the availability of llama_index.
|
||||
|
||||
Raises:
|
||||
ImportError: If llama_index is not available.
|
||||
"""
|
||||
if not LLAMA_INDEX_AVAILABLE:
|
||||
raise ImportError(
|
||||
'llama_index and its dependencies are not installed. '
|
||||
'To use memory features, please run: poetry install --with llama-index.'
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingsLoader:
|
||||
"""Loader for embedding model initialization."""
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
|
||||
"""Initialize and return the appropriate embedding model based on the strategy.
|
||||
|
||||
Parameters:
|
||||
- strategy: The embedding strategy to use.
|
||||
- llm_config: Configuration for the LLM.
|
||||
|
||||
Returns:
|
||||
- An instance of the selected embedding model or None.
|
||||
"""
|
||||
|
||||
if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
return OllamaEmbedding(
|
||||
model_name=strategy,
|
||||
base_url=llm_config.embedding_base_url,
|
||||
ollama_additional_kwargs={'mirostat': 0},
|
||||
)
|
||||
elif strategy == 'openai':
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
api_key=llm_config.api_key.get_secret_value()
|
||||
if llm_config.api_key
|
||||
else None,
|
||||
)
|
||||
elif strategy == 'azureopenai':
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
|
||||
return AzureOpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
deployment_name=llm_config.embedding_deployment_name,
|
||||
api_key=llm_config.api_key.get_secret_value()
|
||||
if llm_config.api_key
|
||||
else None,
|
||||
azure_endpoint=llm_config.base_url,
|
||||
api_version=llm_config.api_version,
|
||||
)
|
||||
elif strategy == 'voyage':
|
||||
from llama_index.embeddings.voyageai import VoyageEmbedding
|
||||
|
||||
return VoyageEmbedding(
|
||||
model_name='voyage-code-3',
|
||||
)
|
||||
elif (strategy is not None) and (strategy.lower() == 'none'):
|
||||
# TODO: this works but is not elegant enough. The incentive is when
|
||||
# an agent using embeddings is not used, there is no reason we need to
|
||||
# initialize an embedding model
|
||||
return None
|
||||
else:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
# initialize the local embedding model
|
||||
local_embed_model = HuggingFaceEmbedding(
|
||||
model_name='BAAI/bge-small-en-v1.5'
|
||||
)
|
||||
|
||||
# for local embeddings, we need torch
|
||||
import torch
|
||||
|
||||
# choose the best device
|
||||
# first determine what is available: CUDA, MPS, or CPU
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = 'mps'
|
||||
else:
|
||||
device = 'cpu'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
os.environ['PYTORCH_FORCE_CPU'] = (
|
||||
'1' # try to force CPU to avoid errors
|
||||
)
|
||||
|
||||
# override CUDA availability
|
||||
torch.cuda.is_available = lambda: False
|
||||
|
||||
# disable MPS to avoid errors
|
||||
if device != 'mps' and hasattr(torch.backends, 'mps'):
|
||||
torch.backends.mps.is_available = lambda: False
|
||||
torch.backends.mps.is_built = False
|
||||
|
||||
# the device being used
|
||||
logger.debug(f'Using device for embeddings: {device}')
|
||||
|
||||
return local_embed_model
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Utility functions to run pipelines, split out for profiling
|
||||
# --------------------------------------------------------------------------
|
||||
def run_pipeline(
|
||||
embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
|
||||
) -> list['TextNode']:
|
||||
"""Run a pipeline embedding documents."""
|
||||
|
||||
# set up a pipeline with the transformations to make
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
embed_model,
|
||||
],
|
||||
)
|
||||
|
||||
# run the pipeline with num_workers
|
||||
nodes = pipeline.run(
|
||||
documents=documents, show_progress=True, num_workers=num_workers
|
||||
)
|
||||
return nodes
|
||||
|
||||
|
||||
def insert_batch_docs(
|
||||
index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
|
||||
) -> list['TextNode']:
|
||||
"""Run the document indexing in parallel."""
|
||||
results = Parallel(n_jobs=num_workers, backend='threading')(
|
||||
delayed(index.insert)(doc) for doc in documents
|
||||
)
|
||||
return results
|
||||
2053
poetry.lock
generated
2053
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -79,17 +79,6 @@ memory-profiler = "^0.61.0"
|
||||
daytona-sdk = "0.9.1"
|
||||
python-json-logger = "^3.2.1"
|
||||
|
||||
[tool.poetry.group.llama-index.dependencies]
|
||||
llama-index = "*"
|
||||
llama-index-vector-stores-chroma = "*"
|
||||
chromadb = "*"
|
||||
llama-index-embeddings-huggingface = "*"
|
||||
torch = "2.5.1"
|
||||
llama-index-embeddings-azure-openai = "*"
|
||||
llama-index-embeddings-ollama = "*"
|
||||
voyageai = "*"
|
||||
llama-index-embeddings-voyageai = "*"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.9.8"
|
||||
mypy = "1.15.0"
|
||||
|
||||
@ -57,8 +57,6 @@ def test_compat_env_to_config(monkeypatch, setup_env):
|
||||
monkeypatch.setenv('WORKSPACE_BASE', '/repos/openhands/workspace')
|
||||
monkeypatch.setenv('LLM_API_KEY', 'sk-proj-rgMV0...')
|
||||
monkeypatch.setenv('LLM_MODEL', 'gpt-4o')
|
||||
monkeypatch.setenv('AGENT_MEMORY_MAX_THREADS', '4')
|
||||
monkeypatch.setenv('AGENT_MEMORY_ENABLED', 'True')
|
||||
monkeypatch.setenv('DEFAULT_AGENT', 'CodeActAgent')
|
||||
monkeypatch.setenv('SANDBOX_TIMEOUT', '10')
|
||||
|
||||
@ -70,9 +68,6 @@ def test_compat_env_to_config(monkeypatch, setup_env):
|
||||
assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
|
||||
assert config.get_llm_config().model == 'gpt-4o'
|
||||
assert isinstance(config.get_agent_config(), AgentConfig)
|
||||
assert isinstance(config.get_agent_config().memory_max_threads, int)
|
||||
assert config.get_agent_config().memory_max_threads == 4
|
||||
assert config.get_agent_config().memory_enabled is True
|
||||
assert config.default_agent == 'CodeActAgent'
|
||||
assert config.sandbox.timeout == 10
|
||||
|
||||
@ -80,7 +75,6 @@ def test_compat_env_to_config(monkeypatch, setup_env):
|
||||
def test_load_from_old_style_env(monkeypatch, default_config):
|
||||
# Test loading configuration from old-style environment variables using monkeypatch
|
||||
monkeypatch.setenv('LLM_API_KEY', 'test-api-key')
|
||||
monkeypatch.setenv('AGENT_MEMORY_ENABLED', 'True')
|
||||
monkeypatch.setenv('DEFAULT_AGENT', 'BrowsingAgent')
|
||||
monkeypatch.setenv('WORKSPACE_BASE', '/opt/files/workspace')
|
||||
monkeypatch.setenv('SANDBOX_BASE_CONTAINER_IMAGE', 'custom_image')
|
||||
@ -88,7 +82,6 @@ def test_load_from_old_style_env(monkeypatch, default_config):
|
||||
load_from_env(default_config, os.environ)
|
||||
|
||||
assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
|
||||
assert default_config.get_agent_config().memory_enabled is True
|
||||
assert default_config.default_agent == 'BrowsingAgent'
|
||||
assert default_config.workspace_base == '/opt/files/workspace'
|
||||
assert default_config.workspace_mount_path is None # before finalize_config
|
||||
@ -110,11 +103,11 @@ model = "some-cheap-model"
|
||||
api_key = "cheap-model-api-key"
|
||||
|
||||
[agent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
|
||||
[agent.BrowsingAgent]
|
||||
llm_config = "cheap"
|
||||
memory_enabled = false
|
||||
enable_prompt_extensions = false
|
||||
|
||||
[sandbox]
|
||||
timeout = 1
|
||||
@ -131,14 +124,16 @@ default_agent = "TestAgent"
|
||||
assert default_config.default_agent == 'TestAgent'
|
||||
assert default_config.get_llm_config().model == 'test-model'
|
||||
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
|
||||
assert default_config.get_agent_config().memory_enabled is True
|
||||
assert default_config.get_agent_config().enable_prompt_extensions is True
|
||||
|
||||
# undefined agent config inherits default ones
|
||||
assert (
|
||||
default_config.get_llm_config_from_agent('CodeActAgent')
|
||||
== default_config.get_llm_config()
|
||||
)
|
||||
assert default_config.get_agent_config('CodeActAgent').memory_enabled is True
|
||||
assert (
|
||||
default_config.get_agent_config('CodeActAgent').enable_prompt_extensions is True
|
||||
)
|
||||
|
||||
# defined agent config overrides default ones
|
||||
assert default_config.get_llm_config_from_agent(
|
||||
@ -148,7 +143,10 @@ default_agent = "TestAgent"
|
||||
default_config.get_llm_config_from_agent('BrowsingAgent').model
|
||||
== 'some-cheap-model'
|
||||
)
|
||||
assert default_config.get_agent_config('BrowsingAgent').memory_enabled is False
|
||||
assert (
|
||||
default_config.get_agent_config('BrowsingAgent').enable_prompt_extensions
|
||||
is False
|
||||
)
|
||||
|
||||
assert default_config.workspace_base == '/opt/files2/workspace'
|
||||
assert default_config.sandbox.timeout == 1
|
||||
@ -438,7 +436,7 @@ def test_core_not_in_toml(default_config, temp_toml_file):
|
||||
model = "test-model"
|
||||
|
||||
[agent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
|
||||
[sandbox]
|
||||
timeout = 1
|
||||
@ -451,7 +449,7 @@ security_analyzer = "semgrep"
|
||||
|
||||
load_from_toml(default_config, temp_toml_file)
|
||||
assert default_config.get_llm_config().model == 'test-model'
|
||||
assert default_config.get_agent_config().memory_enabled is True
|
||||
assert default_config.get_agent_config().enable_prompt_extensions is True
|
||||
assert default_config.sandbox.base_container_image == 'custom_image'
|
||||
assert default_config.sandbox.user_id == 1001
|
||||
assert default_config.security.security_analyzer == 'semgrep'
|
||||
@ -476,7 +474,7 @@ invalid_field = "test"
|
||||
model = "gpt-4"
|
||||
|
||||
[agent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
|
||||
[sandbox]
|
||||
invalid_field_in_sandbox = "test"
|
||||
@ -553,15 +551,6 @@ def test_workspace_mount_rewrite(default_config, monkeypatch):
|
||||
assert default_config.workspace_mount_path == '/sandbox/project'
|
||||
|
||||
|
||||
def test_embedding_base_url_default(default_config):
|
||||
default_config.get_llm_config().base_url = 'https://api.exampleapi.com'
|
||||
finalize_config(default_config)
|
||||
assert (
|
||||
default_config.get_llm_config().embedding_base_url
|
||||
== 'https://api.exampleapi.com'
|
||||
)
|
||||
|
||||
|
||||
def test_cache_dir_creation(default_config, tmpdir):
|
||||
default_config.cache_dir = str(tmpdir.join('test_cache'))
|
||||
finalize_config(default_config)
|
||||
@ -845,7 +834,9 @@ def test_api_keys_repr_str():
|
||||
|
||||
# Test AgentConfig
|
||||
# No attrs in AgentConfig have 'key' or 'token' in their name
|
||||
agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
|
||||
agent_config = AgentConfig(
|
||||
enable_prompt_extensions=True, codeact_enable_browsing=False
|
||||
)
|
||||
for attr_name in AgentConfig.model_fields.keys():
|
||||
if not attr_name.startswith('__'):
|
||||
assert (
|
||||
@ -927,12 +918,10 @@ max_budget_per_task = 4.0
|
||||
[llm.gpt3]
|
||||
model="gpt-3.5-turbo"
|
||||
api_key="redacted"
|
||||
embedding_model="openai"
|
||||
|
||||
[llm.gpt4o]
|
||||
model="gpt-4o"
|
||||
api_key="redacted"
|
||||
embedding_model="openai"
|
||||
"""
|
||||
|
||||
with open(temp_toml_file, 'w') as f:
|
||||
@ -940,7 +929,6 @@ embedding_model="openai"
|
||||
|
||||
llm_config = get_llm_config_arg('gpt3', temp_toml_file)
|
||||
assert llm_config.model == 'gpt-3.5-turbo'
|
||||
assert llm_config.embedding_model == 'openai'
|
||||
|
||||
|
||||
def test_get_agent_configs(default_config, temp_toml_file):
|
||||
@ -950,10 +938,10 @@ max_iterations = 100
|
||||
max_budget_per_task = 4.0
|
||||
|
||||
[agent.CodeActAgent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
|
||||
[agent.BrowsingAgent]
|
||||
memory_max_threads = 10
|
||||
codeact_enable_jupyter = false
|
||||
"""
|
||||
|
||||
with open(temp_toml_file, 'w') as f:
|
||||
@ -962,9 +950,9 @@ memory_max_threads = 10
|
||||
load_from_toml(default_config, temp_toml_file)
|
||||
|
||||
codeact_config = default_config.get_agent_configs().get('CodeActAgent')
|
||||
assert codeact_config.memory_enabled is True
|
||||
assert codeact_config.enable_prompt_extensions is True
|
||||
browsing_config = default_config.get_agent_configs().get('BrowsingAgent')
|
||||
assert browsing_config.memory_max_threads == 10
|
||||
assert browsing_config.codeact_enable_jupyter is False
|
||||
|
||||
|
||||
def test_get_agent_config_arg(temp_toml_file):
|
||||
@ -974,26 +962,24 @@ max_iterations = 100
|
||||
max_budget_per_task = 4.0
|
||||
|
||||
[agent.CodeActAgent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = false
|
||||
codeact_enable_browsing = false
|
||||
|
||||
[agent.BrowsingAgent]
|
||||
memory_enabled = false
|
||||
enable_prompt_extensions = true
|
||||
memory_max_threads = 10
|
||||
codeact_enable_jupyter = false
|
||||
"""
|
||||
|
||||
with open(temp_toml_file, 'w') as f:
|
||||
f.write(temp_toml)
|
||||
|
||||
agent_config = get_agent_config_arg('CodeActAgent', temp_toml_file)
|
||||
assert agent_config.memory_enabled
|
||||
assert not agent_config.enable_prompt_extensions
|
||||
assert not agent_config.codeact_enable_browsing
|
||||
|
||||
agent_config2 = get_agent_config_arg('BrowsingAgent', temp_toml_file)
|
||||
assert not agent_config2.memory_enabled
|
||||
assert agent_config2.enable_prompt_extensions
|
||||
assert agent_config2.memory_max_threads == 10
|
||||
assert not agent_config2.codeact_enable_jupyter
|
||||
|
||||
|
||||
def test_agent_config_custom_group_name(temp_toml_file):
|
||||
@ -1002,10 +988,10 @@ def test_agent_config_custom_group_name(temp_toml_file):
|
||||
max_iterations = 99
|
||||
|
||||
[agent.group1]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
|
||||
[agent.group2]
|
||||
memory_enabled = false
|
||||
enable_prompt_extensions = false
|
||||
"""
|
||||
with open(temp_toml_file, 'w') as f:
|
||||
f.write(temp_toml)
|
||||
@ -1017,9 +1003,9 @@ memory_enabled = false
|
||||
# run_infer in evaluation can use `get_agent_config_arg` to load custom
|
||||
# agent configs with any group name (not just agent name)
|
||||
agent_config1 = get_agent_config_arg('group1', temp_toml_file)
|
||||
assert agent_config1.memory_enabled
|
||||
assert agent_config1.enable_prompt_extensions
|
||||
agent_config2 = get_agent_config_arg('group2', temp_toml_file)
|
||||
assert not agent_config2.memory_enabled
|
||||
assert not agent_config2.enable_prompt_extensions
|
||||
|
||||
|
||||
def test_agent_config_from_toml_section():
|
||||
@ -1028,11 +1014,10 @@ def test_agent_config_from_toml_section():
|
||||
|
||||
# Test with base config and custom configs
|
||||
agent_section = {
|
||||
'memory_enabled': True,
|
||||
'memory_max_threads': 5,
|
||||
'enable_prompt_extensions': True,
|
||||
'CustomAgent1': {'memory_enabled': False, 'codeact_enable_browsing': False},
|
||||
'CustomAgent2': {'memory_max_threads': 10, 'enable_prompt_extensions': False},
|
||||
'codeact_enable_browsing': True,
|
||||
'CustomAgent1': {'codeact_enable_browsing': False},
|
||||
'CustomAgent2': {'enable_prompt_extensions': False},
|
||||
'InvalidAgent': {
|
||||
'invalid_field': 'some_value' # This should be skipped but not affect others
|
||||
},
|
||||
@ -1043,20 +1028,16 @@ def test_agent_config_from_toml_section():
|
||||
|
||||
# Verify the base config was correctly parsed
|
||||
assert 'agent' in result
|
||||
assert result['agent'].memory_enabled is True
|
||||
assert result['agent'].memory_max_threads == 5
|
||||
assert result['agent'].enable_prompt_extensions is True
|
||||
assert result['agent'].codeact_enable_browsing is True
|
||||
|
||||
# Verify custom configs were correctly parsed and inherit from base
|
||||
assert 'CustomAgent1' in result
|
||||
assert result['CustomAgent1'].memory_enabled is False # Overridden
|
||||
assert result['CustomAgent1'].memory_max_threads == 5 # Inherited
|
||||
assert result['CustomAgent1'].codeact_enable_browsing is False # Overridden
|
||||
assert result['CustomAgent1'].enable_prompt_extensions is True # Inherited
|
||||
|
||||
assert 'CustomAgent2' in result
|
||||
assert result['CustomAgent2'].memory_enabled is True # Inherited
|
||||
assert result['CustomAgent2'].memory_max_threads == 10 # Overridden
|
||||
assert result['CustomAgent2'].codeact_enable_browsing is True # Inherited
|
||||
assert result['CustomAgent2'].enable_prompt_extensions is False # Overridden
|
||||
|
||||
# Verify the invalid config was skipped
|
||||
@ -1070,8 +1051,11 @@ def test_agent_config_from_toml_section_with_invalid_base():
|
||||
# Test with invalid base config but valid custom configs
|
||||
agent_section = {
|
||||
'invalid_field': 'some_value', # This should be ignored in base config
|
||||
'memory_max_threads': 'not_an_int', # This should cause validation error
|
||||
'CustomAgent': {'memory_enabled': True, 'memory_max_threads': 8},
|
||||
'codeact_enable_jupyter': 'not_a_bool', # This should cause validation error
|
||||
'CustomAgent': {
|
||||
'codeact_enable_browsing': False,
|
||||
'codeact_enable_jupyter': True,
|
||||
},
|
||||
}
|
||||
|
||||
# Parse the section
|
||||
@ -1079,10 +1063,10 @@ def test_agent_config_from_toml_section_with_invalid_base():
|
||||
|
||||
# Verify a default base config was created despite the invalid fields
|
||||
assert 'agent' in result
|
||||
assert result['agent'].memory_enabled is False # Default value
|
||||
assert result['agent'].memory_max_threads == 3 # Default value
|
||||
assert result['agent'].codeact_enable_browsing is True # Default value
|
||||
assert result['agent'].codeact_enable_jupyter is True # Default value
|
||||
|
||||
# Verify custom config was still processed correctly
|
||||
assert 'CustomAgent' in result
|
||||
assert result['CustomAgent'].memory_enabled is True
|
||||
assert result['CustomAgent'].memory_max_threads == 8
|
||||
assert result['CustomAgent'].codeact_enable_browsing is False
|
||||
assert result['CustomAgent'].codeact_enable_jupyter is True
|
||||
|
||||
@ -8,8 +8,9 @@ from openhands.core.config.utils import load_from_toml
|
||||
|
||||
|
||||
def test_extended_config_from_dict():
|
||||
"""
|
||||
Test that ExtendedConfig.from_dict successfully creates an instance
|
||||
"""Test that ExtendedConfig.from_dict successfully creates an instance.
|
||||
|
||||
This test verifies that the from_dict method correctly creates an instance
|
||||
from a dictionary containing arbitrary extra keys.
|
||||
"""
|
||||
data = {'foo': 'bar', 'baz': 123, 'flag': True}
|
||||
@ -24,9 +25,7 @@ def test_extended_config_from_dict():
|
||||
|
||||
|
||||
def test_extended_config_empty():
|
||||
"""
|
||||
Test that an empty ExtendedConfig can be created and accessed.
|
||||
"""
|
||||
"""Test that an empty ExtendedConfig can be created and accessed."""
|
||||
ext_cfg = ExtendedConfig.from_dict({})
|
||||
assert ext_cfg.root == {}
|
||||
|
||||
@ -36,9 +35,10 @@ def test_extended_config_empty():
|
||||
|
||||
|
||||
def test_extended_config_str_and_repr():
|
||||
"""
|
||||
Test that __str__ and __repr__ return the correct string representations
|
||||
of the ExtendedConfig instance.
|
||||
"""Test that __str__ and __repr__ return the correct string representations.
|
||||
|
||||
This test verifies that the string representations of the ExtendedConfig instance
|
||||
include the expected key/value pairs.
|
||||
"""
|
||||
data = {'alpha': 'test', 'beta': 42}
|
||||
ext_cfg = ExtendedConfig.from_dict(data)
|
||||
@ -54,9 +54,10 @@ def test_extended_config_str_and_repr():
|
||||
|
||||
|
||||
def test_extended_config_getitem_and_getattr():
|
||||
"""
|
||||
Test that __getitem__ and __getattr__ can be used to access values
|
||||
in the ExtendedConfig instance.
|
||||
"""Test that __getitem__ and __getattr__ can be used to access values.
|
||||
|
||||
This test verifies that values in the ExtendedConfig instance can be accessed
|
||||
both via attribute access and dictionary-style access.
|
||||
"""
|
||||
data = {'key1': 'value1', 'key2': 2}
|
||||
ext_cfg = ExtendedConfig.from_dict(data)
|
||||
@ -68,9 +69,7 @@ def test_extended_config_getitem_and_getattr():
|
||||
|
||||
|
||||
def test_extended_config_invalid_key():
|
||||
"""
|
||||
Test that accessing a non-existent key via attribute access raises AttributeError.
|
||||
"""
|
||||
"""Test that accessing a non-existent key via attribute access raises AttributeError."""
|
||||
data = {'existing': 'yes'}
|
||||
ext_cfg = ExtendedConfig.from_dict(data)
|
||||
|
||||
@ -82,9 +81,10 @@ def test_extended_config_invalid_key():
|
||||
|
||||
|
||||
def test_app_config_extended_from_toml(tmp_path: os.PathLike) -> None:
|
||||
"""
|
||||
Test that the [extended] section in a TOML file is correctly loaded into
|
||||
AppConfig.extended and that it accepts arbitrary keys.
|
||||
"""Test that the [extended] section in a TOML file is correctly loaded.
|
||||
|
||||
This test verifies that the [extended] section is loaded into AppConfig.extended
|
||||
and that it accepts arbitrary keys.
|
||||
"""
|
||||
# Create a temporary TOML file with multiple sections including [extended]
|
||||
config_content = """
|
||||
@ -101,7 +101,7 @@ custom2 = 42
|
||||
llm = "overridden" # even a key like 'llm' is accepted in extended
|
||||
|
||||
[agent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
"""
|
||||
config_file = tmp_path / 'config.toml'
|
||||
config_file.write_text(config_content)
|
||||
@ -118,8 +118,9 @@ memory_enabled = true
|
||||
|
||||
|
||||
def test_app_config_extended_default(tmp_path: os.PathLike) -> None:
|
||||
"""
|
||||
Test that if there is no [extended] section in the TOML file,
|
||||
"""Test default behavior when no [extended] section exists.
|
||||
|
||||
This test verifies that if there is no [extended] section in the TOML file,
|
||||
AppConfig.extended remains its default (empty) ExtendedConfig.
|
||||
"""
|
||||
config_content = """
|
||||
@ -131,7 +132,7 @@ model = "test-model"
|
||||
api_key = "toml-api-key"
|
||||
|
||||
[agent]
|
||||
memory_enabled = true
|
||||
enable_prompt_extensions = true
|
||||
"""
|
||||
config_file = tmp_path / 'config.toml'
|
||||
config_file.write_text(config_content)
|
||||
@ -144,8 +145,9 @@ memory_enabled = true
|
||||
|
||||
|
||||
def test_app_config_extended_random_keys(tmp_path: os.PathLike) -> None:
|
||||
"""
|
||||
Test that the extended section accepts arbitrary keys,
|
||||
"""Test that the extended section accepts arbitrary keys.
|
||||
|
||||
This test verifies that the extended section accepts arbitrary keys,
|
||||
including ones not defined in any schema.
|
||||
"""
|
||||
config_content = """
|
||||
|
||||
@ -25,7 +25,6 @@ workspace_base = "./workspace"
|
||||
[llm]
|
||||
model = "base-model"
|
||||
api_key = "base-api-key"
|
||||
embedding_model = "base-embedding"
|
||||
num_retries = 3
|
||||
|
||||
[llm.custom1]
|
||||
@ -60,28 +59,24 @@ def test_load_from_toml_llm_with_fallback(
|
||||
generic_llm = default_config.get_llm_config('llm')
|
||||
assert generic_llm.model == 'base-model'
|
||||
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
|
||||
assert generic_llm.embedding_model == 'base-embedding'
|
||||
assert generic_llm.num_retries == 3
|
||||
|
||||
# Verify custom1 LLM falls back 'num_retries' from base
|
||||
custom1 = default_config.get_llm_config('custom1')
|
||||
assert custom1.model == 'custom-model-1'
|
||||
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
|
||||
assert custom1.embedding_model == 'base-embedding'
|
||||
assert custom1.num_retries == 3 # from [llm]
|
||||
|
||||
# Verify custom2 LLM overrides 'num_retries'
|
||||
custom2 = default_config.get_llm_config('custom2')
|
||||
assert custom2.model == 'custom-model-2'
|
||||
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
|
||||
assert custom2.embedding_model == 'base-embedding'
|
||||
assert custom2.num_retries == 5 # overridden value
|
||||
|
||||
# Verify custom3 LLM inherits all attributes except 'model' and 'api_key'
|
||||
custom3 = default_config.get_llm_config('custom3')
|
||||
assert custom3.model == 'custom-model-3'
|
||||
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
|
||||
assert custom3.embedding_model == 'base-embedding'
|
||||
assert custom3.num_retries == 3 # from [llm]
|
||||
|
||||
|
||||
@ -96,13 +91,11 @@ workspace_base = "./workspace"
|
||||
[llm]
|
||||
model = "base-model"
|
||||
api_key = "base-api-key"
|
||||
embedding_model = "base-embedding"
|
||||
num_retries = 3
|
||||
|
||||
[llm.custom_full]
|
||||
model = "full-custom-model"
|
||||
api_key = "full-custom-api-key"
|
||||
embedding_model = "full-custom-embedding"
|
||||
num_retries = 10
|
||||
"""
|
||||
toml_file = tmp_path / 'full_override_llm.toml'
|
||||
@ -114,14 +107,12 @@ num_retries = 10
|
||||
generic_llm = default_config.get_llm_config('llm')
|
||||
assert generic_llm.model == 'base-model'
|
||||
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
|
||||
assert generic_llm.embedding_model == 'base-embedding'
|
||||
assert generic_llm.num_retries == 3
|
||||
|
||||
# Verify custom_full LLM overrides all attributes
|
||||
custom_full = default_config.get_llm_config('custom_full')
|
||||
assert custom_full.model == 'full-custom-model'
|
||||
assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
|
||||
assert custom_full.embedding_model == 'full-custom-embedding'
|
||||
assert custom_full.num_retries == 10 # overridden value
|
||||
|
||||
|
||||
@ -137,14 +128,12 @@ def test_load_from_toml_llm_custom_partial_override(
|
||||
custom1 = default_config.get_llm_config('custom1')
|
||||
assert custom1.model == 'custom-model-1'
|
||||
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
|
||||
assert custom1.embedding_model == 'base-embedding'
|
||||
assert custom1.num_retries == 3 # from [llm]
|
||||
|
||||
# Verify custom2 LLM overrides 'model', 'api_key', and 'num_retries'
|
||||
custom2 = default_config.get_llm_config('custom2')
|
||||
assert custom2.model == 'custom-model-2'
|
||||
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
|
||||
assert custom2.embedding_model == 'base-embedding'
|
||||
assert custom2.num_retries == 5 # Overridden value
|
||||
|
||||
|
||||
@ -156,11 +145,10 @@ def test_load_from_toml_llm_custom_no_override(
|
||||
"""
|
||||
load_from_toml(default_config, generic_llm_toml)
|
||||
|
||||
# Verify custom3 LLM inherits 'embedding_model' and 'num_retries' from generic
|
||||
# Verify custom3 LLM inherits 'num_retries' from generic
|
||||
custom3 = default_config.get_llm_config('custom3')
|
||||
assert custom3.model == 'custom-model-3'
|
||||
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
|
||||
assert custom3.embedding_model == 'base-embedding'
|
||||
assert custom3.num_retries == 3 # from [llm]
|
||||
|
||||
|
||||
@ -187,7 +175,6 @@ api_key = "custom-only-api-key"
|
||||
custom_only = default_config.get_llm_config('custom_only')
|
||||
assert custom_only.model == 'custom-only-model'
|
||||
assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
|
||||
assert custom_only.embedding_model == 'local' # default value
|
||||
assert custom_only.num_retries == 4 # default value
|
||||
|
||||
|
||||
@ -225,4 +212,3 @@ unknown_attr = "should_not_exist"
|
||||
assert custom_invalid.model == 'base-model'
|
||||
assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
|
||||
assert custom_invalid.num_retries == 3 # default value
|
||||
assert custom_invalid.embedding_model == 'local' # default value
|
||||
|
||||
@ -103,5 +103,3 @@ def test_draft_editor_fallback(config_toml_with_draft_editor):
|
||||
draft_editor_config = config.get_llm_config('draft_editor')
|
||||
# num_retries is an example default from llm section
|
||||
assert draft_editor_config.num_retries == 7
|
||||
# embedding_model is defaulted in the LLMConfig class
|
||||
assert draft_editor_config.embedding_model == 'local'
|
||||
|
||||
@ -1,251 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.memory.long_term_memory import LongTermMemory
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_config() -> LLMConfig:
|
||||
config = MagicMock(spec=LLMConfig)
|
||||
config.embedding_model = 'test_embedding_model'
|
||||
config.api_key = 'test_api_key'
|
||||
config.api_version = 'v1'
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_config() -> AgentConfig:
|
||||
config = AgentConfig(
|
||||
memory_enabled=True,
|
||||
memory_max_threads=4,
|
||||
llm_config='test_llm_config',
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store() -> FileStore:
|
||||
store = MagicMock(spec=FileStore)
|
||||
store.sid = 'test_session'
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream(mock_file_store: FileStore) -> EventStream:
|
||||
with patch('openhands.events.stream.EventStream') as MockEventStream:
|
||||
instance = MockEventStream.return_value
|
||||
instance.sid = 'test_session'
|
||||
instance.get_events = MagicMock()
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(
|
||||
mock_llm_config: LLMConfig,
|
||||
mock_agent_config: AgentConfig,
|
||||
mock_event_stream: EventStream,
|
||||
) -> LongTermMemory:
|
||||
mod = LongTermMemory.__module__
|
||||
with patch(f'{mod}.chromadb.PersistentClient') as mock_chroma_client:
|
||||
mock_collection = MagicMock()
|
||||
mock_chroma_client.return_value.get_or_create_collection.return_value = (
|
||||
mock_collection
|
||||
)
|
||||
with (
|
||||
patch(f'{mod}.ChromaVectorStore', MagicMock()),
|
||||
patch(f'{mod}.EmbeddingsLoader', MagicMock()),
|
||||
patch(f'{mod}.VectorStoreIndex', MagicMock()),
|
||||
):
|
||||
memory = LongTermMemory(
|
||||
llm_config=mock_llm_config,
|
||||
agent_config=mock_agent_config,
|
||||
event_stream=mock_event_stream,
|
||||
)
|
||||
memory.collection = mock_collection
|
||||
return memory
|
||||
|
||||
|
||||
def _create_action_event(action: str) -> Event:
|
||||
"""Helper function to create an action event."""
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.AGENT
|
||||
event.action = action
|
||||
return event
|
||||
|
||||
|
||||
def _create_observation_event(observation: str) -> Event:
|
||||
"""Helper function to create an observation event."""
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.ENVIRONMENT
|
||||
event.observation = observation
|
||||
return event
|
||||
|
||||
|
||||
def test_add_event_with_action(long_term_memory: LongTermMemory):
|
||||
event = _create_action_event('test_action')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'action'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_action'
|
||||
|
||||
|
||||
def test_add_event_with_observation(long_term_memory: LongTermMemory):
|
||||
event = _create_observation_event('test_observation')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'observation'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_observation'
|
||||
|
||||
|
||||
def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
|
||||
# Creating an event with additional unexpected attributes
|
||||
event = Event()
|
||||
event._id = -1
|
||||
event._timestamp = datetime.now(timezone.utc).isoformat()
|
||||
event._source = EventSource.AGENT
|
||||
event.action = 'test_action'
|
||||
event.unexpected_key = 'value'
|
||||
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
assert long_term_memory.thought_idx == 1
|
||||
long_term_memory._add_document.assert_called_once()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert kwargs['document'].extra_info['type'] == 'action'
|
||||
assert kwargs['document'].extra_info['id'] == 'test_action'
|
||||
|
||||
|
||||
def test_events_to_docs_no_events(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
mock_event_stream.get_events.side_effect = FileNotFoundError
|
||||
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# since get_events raises, documents should be empty
|
||||
assert len(documents) == 0
|
||||
|
||||
# thought_idx remains unchanged
|
||||
assert long_term_memory.thought_idx == 0
|
||||
|
||||
|
||||
def test_load_events_into_index_with_invalid_json(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
"""Test loading events with malformed event data."""
|
||||
# Simulate an event that causes event_to_memory to raise a JSONDecodeError
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.event_to_memory',
|
||||
side_effect=json.JSONDecodeError('Expecting value', '', 0),
|
||||
):
|
||||
event = _create_action_event('invalid_action')
|
||||
mock_event_stream.get_events.return_value = [event]
|
||||
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# since event_to_memory raises, documents should be empty
|
||||
assert len(documents) == 0
|
||||
|
||||
# thought_idx remains unchanged
|
||||
assert long_term_memory.thought_idx == 0
|
||||
|
||||
|
||||
def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
|
||||
event = _create_action_event('test_action')
|
||||
long_term_memory._add_document = MagicMock()
|
||||
long_term_memory.add_event(event)
|
||||
long_term_memory._add_document.assert_called()
|
||||
_, kwargs = long_term_memory._add_document.call_args
|
||||
assert 'document' in kwargs
|
||||
assert (
|
||||
kwargs['document'].text
|
||||
== '{"source": "agent", "action": "test_action", "args": {}}'
|
||||
)
|
||||
|
||||
|
||||
def test_search_returns_correct_results(long_term_memory: LongTermMemory):
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [
|
||||
MagicMock(get_text=MagicMock(return_value='result1')),
|
||||
MagicMock(get_text=MagicMock(return_value='result2')),
|
||||
]
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.VectorIndexRetriever',
|
||||
return_value=mock_retriever,
|
||||
):
|
||||
results = long_term_memory.search(query='test query', k=2)
|
||||
assert results == ['result1', 'result2']
|
||||
mock_retriever.retrieve.assert_called_once_with('test query')
|
||||
|
||||
|
||||
def test_search_with_no_results(long_term_memory: LongTermMemory):
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
with patch(
|
||||
'openhands.memory.long_term_memory.VectorIndexRetriever',
|
||||
return_value=mock_retriever,
|
||||
):
|
||||
results = long_term_memory.search(query='no results', k=5)
|
||||
assert results == []
|
||||
mock_retriever.retrieve.assert_called_once_with('no results')
|
||||
|
||||
|
||||
def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
|
||||
event1 = _create_action_event('action1')
|
||||
event2 = _create_observation_event('observation1')
|
||||
long_term_memory.add_event(event1)
|
||||
long_term_memory.add_event(event2)
|
||||
assert long_term_memory.thought_idx == 2
|
||||
|
||||
|
||||
def test_load_events_batch_insert(
|
||||
long_term_memory: LongTermMemory, mock_event_stream: EventStream
|
||||
):
|
||||
event1 = _create_action_event('action1')
|
||||
event2 = _create_observation_event('observation1')
|
||||
event3 = _create_action_event('action2')
|
||||
mock_event_stream.get_events.return_value = [event1, event2, event3]
|
||||
|
||||
# Mock insert_batch_docs
|
||||
with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
|
||||
# convert events to documents
|
||||
documents = long_term_memory._events_to_docs()
|
||||
|
||||
# Mock the insert_batch_docs to simulate document insertion
|
||||
mock_run_docs.return_value = []
|
||||
|
||||
# Call insert_batch_docs with the documents
|
||||
mock_run_docs(
|
||||
index=long_term_memory.index,
|
||||
documents=documents,
|
||||
num_workers=long_term_memory.memory_max_threads,
|
||||
)
|
||||
|
||||
# Assert that insert_batch_docs was called with the correct arguments
|
||||
mock_run_docs.assert_called_once_with(
|
||||
index=long_term_memory.index,
|
||||
documents=documents,
|
||||
num_workers=long_term_memory.memory_max_threads,
|
||||
)
|
||||
|
||||
# Check if thought_idx was incremented correctly
|
||||
assert long_term_memory.thought_idx == 3
|
||||
@ -30,8 +30,8 @@ def event_stream(temp_dir):
|
||||
@pytest.fixture
|
||||
def agent_configs():
|
||||
return {
|
||||
'CoderAgent': AgentConfig(memory_enabled=True),
|
||||
'BrowsingAgent': AgentConfig(memory_enabled=True),
|
||||
'CoderAgent': AgentConfig(enable_prompt_extensions=True),
|
||||
'BrowsingAgent': AgentConfig(enable_prompt_extensions=True),
|
||||
}
|
||||
|
||||
|
||||
@ -91,8 +91,9 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict
|
||||
|
||||
|
||||
def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: dict):
|
||||
"""When there's no codebase_summary available, there shouldn't be any prompt
|
||||
about 'code summary'
|
||||
"""When there's no codebase_summary available, there shouldn't be any prompt about 'code summary'.
|
||||
|
||||
This test verifies that the prompt doesn't include code summary text when no summary is provided.
|
||||
"""
|
||||
mock_llm = MagicMock()
|
||||
content = json.dumps({'action': 'finish', 'args': {}})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user