remove llamaindex (#7151)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Robert Brennan 2025-03-11 18:28:56 -04:00 committed by GitHub
parent 924acb182b
commit 5128377baa
24 changed files with 193 additions and 2790 deletions

View File

@@ -10,12 +10,6 @@ updates:
pre-commit:
patterns:
- "pre-commit"
llama:
patterns:
- "llama*"
chromadb:
patterns:
- "chromadb"
browsergym:
patterns:
- "browsergym*"

View File

@@ -36,7 +36,7 @@ jobs:
python-version: '3.12'
cache: 'poetry'
- name: Install Python dependencies using Poetry
-run: poetry install --without evaluation,llama-index
+run: poetry install --without evaluation
- name: Build Environment
run: make build
- name: Run tests

View File

@@ -54,7 +54,7 @@ jobs:
Hi! I started running the integration tests on your PR. You will receive a comment with the results shortly.
- name: Install Python dependencies using Poetry
-run: poetry install --without evaluation,llama-index
+run: poetry install --without evaluation
- name: Configure config.toml for testing with Haiku
env:

View File

@@ -44,11 +44,11 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
- name: Install Python dependencies using Poetry
-run: poetry install --without evaluation,llama-index
+run: poetry install --without evaluation
- name: Build Environment
run: make build
- name: Run Tests
-run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_long_term_memory.py
+run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
env:

View File

@@ -133,7 +133,7 @@ install-python-dependencies:
export HNSWLIB_NO_NATIVE=1; \
poetry run pip install chroma-hnswlib; \
fi
-@poetry install --without llama-index
+@poetry install
@if [ -f "/etc/manjaro-release" ]; then \
echo "$(BLUE)Detected Manjaro Linux. Installing Playwright dependencies...$(RESET)"; \
poetry run pip install playwright; \
@@ -265,35 +265,6 @@ setup-config-prompts:
@read -p "Enter your LLM base URL [mostly used for local LLMs, leave blank if not needed - example: http://localhost:5001/v1/]: " llm_base_url; \
if [[ ! -z "$$llm_base_url" ]]; then echo "base_url=\"$$llm_base_url\"" >> $(CONFIG_FILE).tmp; fi
@echo "Enter your LLM Embedding Model"; \
echo "Choices are:"; \
echo " - openai"; \
echo " - azureopenai"; \
echo " - Embeddings available only with OllamaEmbedding:"; \
echo " - llama2"; \
echo " - mxbai-embed-large"; \
echo " - nomic-embed-text"; \
echo " - all-minilm"; \
echo " - stable-code"; \
echo " - bge-m3"; \
echo " - bge-large"; \
echo " - paraphrase-multilingual"; \
echo " - snowflake-arctic-embed"; \
echo " - Leave blank to default to 'BAAI/bge-small-en-v1.5' via huggingface"; \
read -p "> " llm_embedding_model; \
echo "embedding_model=\"$$llm_embedding_model\"" >> $(CONFIG_FILE).tmp; \
if [ "$$llm_embedding_model" = "llama2" ] || [ "$$llm_embedding_model" = "mxbai-embed-large" ] || [ "$$llm_embedding_model" = "nomic-embed-text" ] || [ "$$llm_embedding_model" = "all-minilm" ] || [ "$$llm_embedding_model" = "stable-code" ]; then \
read -p "Enter the local model URL for the embedding model (will set llm.embedding_base_url): " llm_embedding_base_url; \
echo "embedding_base_url=\"$$llm_embedding_base_url\"" >> $(CONFIG_FILE).tmp; \
elif [ "$$llm_embedding_model" = "azureopenai" ]; then \
read -p "Enter the Azure endpoint URL (will overwrite llm.base_url): " llm_base_url; \
echo "base_url=\"$$llm_base_url\"" >> $(CONFIG_FILE).tmp; \
read -p "Enter the Azure LLM Embedding Deployment Name: " llm_embedding_deployment_name; \
echo "embedding_deployment_name=\"$$llm_embedding_deployment_name\"" >> $(CONFIG_FILE).tmp; \
read -p "Enter the Azure API Version: " llm_api_version; \
echo "api_version=\"$$llm_api_version\"" >> $(CONFIG_FILE).tmp; \
fi
# Develop in container
docker-dev:

View File

@@ -132,15 +132,6 @@ api_key = ""
# Custom LLM provider
#custom_llm_provider = ""
# Embedding API base URL
#embedding_base_url = ""
# Embedding deployment name
#embedding_deployment_name = ""
# Embedding model to use
embedding_model = "local"
# Maximum number of characters in an observation's content
#max_message_chars = 10000

View File

@@ -26,7 +26,7 @@ RUN apt-get update -y \
COPY ./pyproject.toml ./poetry.lock ./
RUN touch README.md
-RUN export POETRY_CACHE_DIR && poetry install --without evaluation,llama-index --no-root && rm -rf $POETRY_CACHE_DIR
+RUN export POETRY_CACHE_DIR && poetry install --without evaluation --no-root && rm -rf $POETRY_CACHE_DIR
FROM python:3.12.3-slim AS openhands-app

View File

@@ -197,21 +197,6 @@ For development setups, you can also define custom named LLM configurations. See
- Default: `""`
- Description: Custom LLM provider
### Embeddings
- `embedding_base_url`
- Type: `str`
- Default: `""`
- Description: Embedding API base URL
- `embedding_deployment_name`
- Type: `str`
- Default: `""`
- Description: Embedding deployment name
- `embedding_model`
- Type: `str`
- Default: `"local"`
- Description: Embedding model to use
### Message Handling
- `max_message_chars`
@@ -302,16 +287,6 @@ The agent configuration options are defined in the `[agent]` and `[agent.<agent_
- Default: `""`
- Description: Name of the micro agent to use for this agent
### Memory Configuration
- `memory_enabled`
- Type: `bool`
- Default: `false`
- Description: Whether long-term memory (embeddings) is enabled
- `memory_max_threads`
- Type: `int`
- Default: `3`
- Description: The maximum number of threads indexing at the same time for embeddings
### LLM Configuration
- `llm_config`

View File

@@ -31,17 +31,11 @@ You will need your ChatGPT deployment name which can be found on the deployments
- `Base URL` to your Azure API Base URL (e.g. `https://example-endpoint.openai.azure.com`)
- `API Key` to your Azure API key
-## Embeddings
-OpenHands uses llama-index for embeddings. You can find their documentation on Azure [here](https://docs.llamaindex.ai/en/stable/api_reference/embeddings/azure_openai/).
-### Azure OpenAI Configuration
-When running OpenHands, set the following environment variables using `-e` in the
+When running OpenHands, set the following environment variable using `-e` in the
[docker run command](/modules/usage/installation#start-the-app):
```
-LLM_EMBEDDING_MODEL="azureopenai"
-LLM_EMBEDDING_DEPLOYMENT_NAME="<your-embedding-deployment-name>" # e.g. "TextEmbedding...<etc>"
LLM_API_VERSION="<api-version>" # e.g. "2024-02-15-preview"
```
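With embeddings gone, the Azure setup reduces to the ordinary LLM settings above. As a rough illustration (not part of this commit), the equivalent programmatic configuration, assuming the `LLMConfig` fields shown elsewhere in this diff; all values are placeholders:

```python
from openhands.core.config import LLMConfig

# placeholders, not defaults from the codebase
azure_llm = LLMConfig(
    model='azure/<your-gpt-deployment-name>',
    base_url='https://example-endpoint.openai.azure.com',
    api_key='<your-azure-api-key>',  # coerced to SecretStr by pydantic
    api_version='2024-02-15-preview',
)
```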

View File

@@ -14,8 +14,6 @@ class AgentConfig(BaseModel):
codeact_enable_browsing: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling.
codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False.
memory_enabled: Whether long-term memory (embeddings) is enabled.
memory_max_threads: The maximum number of threads indexing at the same time for embeddings. (deprecated)
llm_config: The name of the llm config to use. If specified, this will override global llm config.
enable_prompt_extensions: Whether to use prompt extensions (e.g., microagents, inject runtime info). Default is True.
disabled_microagents: A list of microagents to disable (by name, without .py extension, e.g. ["github", "lint"]). Default is None.
@@ -25,8 +23,6 @@ class AgentConfig(BaseModel):
"""
llm_config: str | None = Field(default=None)
memory_enabled: bool = Field(default=False)
memory_max_threads: int = Field(default=3)
codeact_enable_browsing: bool = Field(default=True)
codeact_enable_llm_editor: bool = Field(default=False)
codeact_enable_jupyter: bool = Field(default=True)
@@ -50,11 +46,10 @@
Example:
Apply generic agent config with custom agent overrides, e.g.
[agent]
-memory_enabled = false
-enable_prompt_extensions = true
+enable_prompt_extensions = false
[agent.BrowsingAgent]
-memory_enabled = true
-results in memory_enabled being true for BrowsingAgent but false for others.
+enable_prompt_extensions = true
+results in prompt_extensions being true for BrowsingAgent but false for others.
Returns:
dict[str, AgentConfig]: A mapping where the key "agent" corresponds to the default configuration

View File

@@ -15,11 +15,8 @@ class LLMConfig(BaseModel):
Attributes:
model: The model to use.
api_key: The API key to use.
-base_url: The base URL for the API. This is necessary for local LLMs. It is also used for Azure embeddings.
+base_url: The base URL for the API. This is necessary for local LLMs.
api_version: The version of the API.
embedding_model: The embedding model to use.
embedding_base_url: The base URL for the embedding API.
embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
aws_access_key_id: The AWS access key ID.
aws_secret_access_key: The AWS secret access key.
aws_region_name: The AWS region name.
@@ -52,9 +49,6 @@
api_key: SecretStr | None = Field(default=None)
base_url: str | None = Field(default=None)
api_version: str | None = Field(default=None)
embedding_model: str = Field(default='local')
embedding_base_url: str | None = Field(default=None)
embedding_deployment_name: str | None = Field(default=None)
aws_access_key_id: SecretStr | None = Field(default=None)
aws_secret_access_key: SecretStr | None = Field(default=None)
aws_region_name: str | None = Field(default=None)
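A minimal sketch of constructing the slimmed-down config, using only fields that survive this diff (the values are placeholders):

```python
from openhands.core.config import LLMConfig

llm = LLMConfig(
    model='gpt-4o',
    api_key='<your-api-key>',  # coerced to SecretStr by pydantic
    base_url=None,  # only needed for local or proxied endpoints
)
```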

View File

@@ -282,8 +282,6 @@ def finalize_config(cfg: AppConfig):
# make sure log_completions_folder is an absolute path
for llm in cfg.llms.values():
llm.log_completions_folder = os.path.abspath(llm.log_completions_folder)
if llm.embedding_base_url is None:
llm.embedding_base_url = llm.base_url
if cfg.sandbox.use_host_network and platform.system() == 'Darwin':
logger.openhands_logger.warning(

View File

@@ -2,7 +2,6 @@
- Short Term History
- Memory Condenser
- Long Term Memory
## Short Term History
- Short term history filters the event stream and computes the messages that are injected into the context
@@ -17,7 +16,3 @@
- Then it does the same for later chunks of events between user messages
- If there are no more agent events, it summarizes the user messages, this time one by one, if they're large enough and not immediately after an AgentFinishAction event (we assume those are tasks, potentially important)
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
## Long Term Memory
- Long term memory component stores embeddings for events and prompts in a vector store
- The agent can query it when it needs detailed information about a past event or to learn new actions

View File

@@ -1,4 +1,3 @@
from openhands.memory.condenser import Condenser
-from openhands.memory.long_term_memory import LongTermMemory
-__all__ = ['LongTermMemory', 'Condenser']
+__all__ = ['Condenser']
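After this change, downstream code imports only the condenser from the package, e.g.:

```python
# LongTermMemory is gone; the condenser is the sole export
from openhands.memory import Condenser
```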

View File

@@ -1,188 +0,0 @@
import json
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_memory
from openhands.events.stream import EventStream
from openhands.utils.embeddings import (
LLAMA_INDEX_AVAILABLE,
EmbeddingsLoader,
check_llama_index,
)
# Conditional imports based on llama_index availability
if LLAMA_INDEX_AVAILABLE:
import chromadb
from llama_index.core import Document
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class LongTermMemory:
"""Handles storing information for the agent to access later, using chromadb."""
event_stream: EventStream
def __init__(
self,
llm_config: LLMConfig,
agent_config: AgentConfig,
event_stream: EventStream,
):
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
check_llama_index()
# initialize the chromadb client
db = chromadb.PersistentClient(
path=f'./cache/sessions/{event_stream.sid}/memory',
# FIXME anonymized_telemetry=False,
)
self.collection = db.get_or_create_collection(name='memories')
vector_store = ChromaVectorStore(chroma_collection=self.collection)
# embedding model
embedding_strategy = llm_config.embedding_model
self.embed_model = EmbeddingsLoader.get_embedding_model(
embedding_strategy, llm_config
)
logger.debug(f'Using embedding model: {self.embed_model}')
# instantiate the index
self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
self.thought_idx = 0
# initialize the event stream
self.event_stream = event_stream
# max of threads to run the pipeline
self.memory_max_threads = agent_config.memory_max_threads
def add_event(self, event: Event):
"""Adds a new event to the long term memory with a unique id.
Parameters:
- event: The new event to be added to memory
"""
try:
# convert the event to a memory-friendly format, and don't truncate
event_data = event_to_memory(event, -1)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f'Failed to process event: {e}')
return
# determine the event type and ID
event_type = ''
event_id = ''
if 'action' in event_data:
event_type = 'action'
event_id = event_data['action']
elif 'observation' in event_data:
event_type = 'observation'
event_id = event_data['observation']
# create a Document instance for the event
doc = Document(
text=json.dumps(event_data),
doc_id=str(self.thought_idx),
extra_info={
'type': event_type,
'id': event_id,
'idx': self.thought_idx,
},
)
self.thought_idx += 1
logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
self._add_document(document=doc)
def _add_document(self, document: 'Document'):
"""Inserts a single document into the index."""
self.index.insert_nodes([self._create_node(document)])
def _create_node(self, document: 'Document') -> 'TextNode':
"""Create a TextNode from a Document instance."""
return TextNode(
text=document.text,
doc_id=document.doc_id,
extra_info=document.extra_info,
)
def search(self, query: str, k: int = 10) -> list[str]:
"""Searches through the current memory using VectorIndexRetriever.
Parameters:
- query (str): A query to match search results to
- k (int): Number of top results to return
Returns:
- list[str]: List of top k results found in current memory
"""
retriever = VectorIndexRetriever(
index=self.index,
similarity_top_k=k,
)
results = retriever.retrieve(query)
for result in results:
logger.debug(
f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
)
return [r.get_text() for r in results]
def _events_to_docs(self) -> list['Document']:
"""Convert all events from the EventStream to documents for batch insert into the index."""
try:
events = self.event_stream.get_events()
except Exception as e:
logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
return []
documents: list[Document] = []
for event in events:
try:
# convert the event to a memory-friendly format, and don't truncate
event_data = event_to_memory(event, -1)
# determine the event type and ID
event_type = ''
event_id = ''
if 'action' in event_data:
event_type = 'action'
event_id = event_data['action']
elif 'observation' in event_data:
event_type = 'observation'
event_id = event_data['observation']
# create a Document instance for the event
doc = Document(
text=json.dumps(event_data),
doc_id=str(self.thought_idx),
extra_info={
'type': event_type,
'id': event_id,
'idx': self.thought_idx,
},
)
documents.append(doc)
self.thought_idx += 1
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f'Failed to process event: {e}')
continue
if documents:
logger.debug(f'Batch inserting {len(documents)} documents into the index.')
else:
logger.debug('No valid documents found to insert into the index.')
return documents
def create_nodes(self, documents: list['Document']) -> list['TextNode']:
"""Create nodes from a list of documents."""
return [self._create_node(doc) for doc in documents]
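For anyone who relied on the removed class: it stored event JSON in a chromadb collection named 'memories'. A rough, llama-index-free approximation with plain chromadb is sketched below; the session path and collection name follow the removed code above, while everything else, including the use of chromadb's default embedding function, is an assumption:

```python
import json

import chromadb

# '<sid>' stands in for the session id used by the removed code
client = chromadb.PersistentClient(path='./cache/sessions/<sid>/memory')
collection = client.get_or_create_collection(name='memories')

def add_event(event_data: dict, idx: int) -> None:
    # chromadb embeds the document text with its default embedding function
    collection.add(documents=[json.dumps(event_data)], ids=[str(idx)])

def search(query: str, k: int = 10) -> list[str]:
    results = collection.query(query_texts=[query], n_results=k)
    return results['documents'][0]
```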

View File

@@ -1,187 +0,0 @@
import importlib.util
import os
from joblib import Parallel, delayed
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
try:
# check if those we need later are available using importlib
if importlib.util.find_spec('chromadb') is None:
raise ImportError(
'chromadb is not available. Please install it using poetry install --with llama-index'
)
if (
importlib.util.find_spec(
'llama_index.core.indices.vector_store.retrievers.retriever'
)
is None
or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
is None
):
raise ImportError(
'llama_index is not available. Please install it using poetry install --with llama-index'
)
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.schema import TextNode
LLAMA_INDEX_AVAILABLE = True
except ImportError:
LLAMA_INDEX_AVAILABLE = False
# Define supported embedding models
SUPPORTED_OLLAMA_EMBED_MODELS = [
'llama2',
'mxbai-embed-large',
'nomic-embed-text',
'all-minilm',
'stable-code',
'bge-m3',
'bge-large',
'paraphrase-multilingual',
'snowflake-arctic-embed',
]
def check_llama_index():
"""Utility function to check the availability of llama_index.
Raises:
ImportError: If llama_index is not available.
"""
if not LLAMA_INDEX_AVAILABLE:
raise ImportError(
'llama_index and its dependencies are not installed. '
'To use memory features, please run: poetry install --with llama-index.'
)
class EmbeddingsLoader:
"""Loader for embedding model initialization."""
@staticmethod
def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
"""Initialize and return the appropriate embedding model based on the strategy.
Parameters:
- strategy: The embedding strategy to use.
- llm_config: Configuration for the LLM.
Returns:
- An instance of the selected embedding model or None.
"""
if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
from llama_index.embeddings.ollama import OllamaEmbedding
return OllamaEmbedding(
model_name=strategy,
base_url=llm_config.embedding_base_url,
ollama_additional_kwargs={'mirostat': 0},
)
elif strategy == 'openai':
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key.get_secret_value()
if llm_config.api_key
else None,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
return AzureOpenAIEmbedding(
model='text-embedding-ada-002',
deployment_name=llm_config.embedding_deployment_name,
api_key=llm_config.api_key.get_secret_value()
if llm_config.api_key
else None,
azure_endpoint=llm_config.base_url,
api_version=llm_config.api_version,
)
elif strategy == 'voyage':
from llama_index.embeddings.voyageai import VoyageEmbedding
return VoyageEmbedding(
model_name='voyage-code-3',
)
elif (strategy is not None) and (strategy.lower() == 'none'):
# TODO: this works but is not elegant enough. The incentive is when
# an agent using embeddings is not used, there is no reason we need to
# initialize an embedding model
return None
else:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# initialize the local embedding model
local_embed_model = HuggingFaceEmbedding(
model_name='BAAI/bge-small-en-v1.5'
)
# for local embeddings, we need torch
import torch
# choose the best device
# first determine what is available: CUDA, MPS, or CPU
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = 'mps'
else:
device = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['PYTORCH_FORCE_CPU'] = (
'1' # try to force CPU to avoid errors
)
# override CUDA availability
torch.cuda.is_available = lambda: False
# disable MPS to avoid errors
if device != 'mps' and hasattr(torch.backends, 'mps'):
torch.backends.mps.is_available = lambda: False
torch.backends.mps.is_built = False
# the device being used
logger.debug(f'Using device for embeddings: {device}')
return local_embed_model
# --------------------------------------------------------------------------
# Utility functions to run pipelines, split out for profiling
# --------------------------------------------------------------------------
def run_pipeline(
embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
) -> list['TextNode']:
"""Run a pipeline embedding documents."""
# set up a pipeline with the transformations to make
pipeline = IngestionPipeline(
transformations=[
embed_model,
],
)
# run the pipeline with num_workers
nodes = pipeline.run(
documents=documents, show_progress=True, num_workers=num_workers
)
return nodes
def insert_batch_docs(
index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
) -> list['TextNode']:
"""Run the document indexing in parallel."""
results = Parallel(n_jobs=num_workers, backend='threading')(
delayed(index.insert)(doc) for doc in documents
)
return results

poetry.lock (generated): 2053 lines changed. Diff suppressed because it is too large.

View File

@@ -79,17 +79,6 @@ memory-profiler = "^0.61.0"
daytona-sdk = "0.9.1"
python-json-logger = "^3.2.1"
[tool.poetry.group.llama-index.dependencies]
llama-index = "*"
llama-index-vector-stores-chroma = "*"
chromadb = "*"
llama-index-embeddings-huggingface = "*"
torch = "2.5.1"
llama-index-embeddings-azure-openai = "*"
llama-index-embeddings-ollama = "*"
voyageai = "*"
llama-index-embeddings-voyageai = "*"
[tool.poetry.group.dev.dependencies]
ruff = "0.9.8"
mypy = "1.15.0"

View File

@@ -57,8 +57,6 @@ def test_compat_env_to_config(monkeypatch, setup_env):
monkeypatch.setenv('WORKSPACE_BASE', '/repos/openhands/workspace')
monkeypatch.setenv('LLM_API_KEY', 'sk-proj-rgMV0...')
monkeypatch.setenv('LLM_MODEL', 'gpt-4o')
monkeypatch.setenv('AGENT_MEMORY_MAX_THREADS', '4')
monkeypatch.setenv('AGENT_MEMORY_ENABLED', 'True')
monkeypatch.setenv('DEFAULT_AGENT', 'CodeActAgent')
monkeypatch.setenv('SANDBOX_TIMEOUT', '10')
@@ -70,9 +68,6 @@
assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
assert config.get_llm_config().model == 'gpt-4o'
assert isinstance(config.get_agent_config(), AgentConfig)
assert isinstance(config.get_agent_config().memory_max_threads, int)
assert config.get_agent_config().memory_max_threads == 4
assert config.get_agent_config().memory_enabled is True
assert config.default_agent == 'CodeActAgent'
assert config.sandbox.timeout == 10
@@ -80,7 +75,6 @@
def test_load_from_old_style_env(monkeypatch, default_config):
# Test loading configuration from old-style environment variables using monkeypatch
monkeypatch.setenv('LLM_API_KEY', 'test-api-key')
monkeypatch.setenv('AGENT_MEMORY_ENABLED', 'True')
monkeypatch.setenv('DEFAULT_AGENT', 'BrowsingAgent')
monkeypatch.setenv('WORKSPACE_BASE', '/opt/files/workspace')
monkeypatch.setenv('SANDBOX_BASE_CONTAINER_IMAGE', 'custom_image')
@@ -88,7 +82,6 @@
load_from_env(default_config, os.environ)
assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
assert default_config.get_agent_config().memory_enabled is True
assert default_config.default_agent == 'BrowsingAgent'
assert default_config.workspace_base == '/opt/files/workspace'
assert default_config.workspace_mount_path is None # before finalize_config
@@ -110,11 +103,11 @@ model = "some-cheap-model"
api_key = "cheap-model-api-key"
[agent]
-memory_enabled = true
+enable_prompt_extensions = true
[agent.BrowsingAgent]
llm_config = "cheap"
-memory_enabled = false
+enable_prompt_extensions = false
[sandbox]
timeout = 1
@@ -131,14 +124,16 @@ default_agent = "TestAgent"
assert default_config.default_agent == 'TestAgent'
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
-assert default_config.get_agent_config().memory_enabled is True
+assert default_config.get_agent_config().enable_prompt_extensions is True
# undefined agent config inherits default ones
assert (
default_config.get_llm_config_from_agent('CodeActAgent')
== default_config.get_llm_config()
)
-assert default_config.get_agent_config('CodeActAgent').memory_enabled is True
+assert (
+default_config.get_agent_config('CodeActAgent').enable_prompt_extensions is True
+)
# defined agent config overrides default ones
assert default_config.get_llm_config_from_agent(
@@ -148,7 +143,10 @@ default_agent = "TestAgent"
default_config.get_llm_config_from_agent('BrowsingAgent').model
== 'some-cheap-model'
)
-assert default_config.get_agent_config('BrowsingAgent').memory_enabled is False
+assert (
+default_config.get_agent_config('BrowsingAgent').enable_prompt_extensions
+is False
+)
assert default_config.workspace_base == '/opt/files2/workspace'
assert default_config.sandbox.timeout == 1
@@ -438,7 +436,7 @@ def test_core_not_in_toml(default_config, temp_toml_file):
model = "test-model"
[agent]
-memory_enabled = true
+enable_prompt_extensions = true
[sandbox]
timeout = 1
@@ -451,7 +449,7 @@ security_analyzer = "semgrep"
load_from_toml(default_config, temp_toml_file)
assert default_config.get_llm_config().model == 'test-model'
-assert default_config.get_agent_config().memory_enabled is True
+assert default_config.get_agent_config().enable_prompt_extensions is True
assert default_config.sandbox.base_container_image == 'custom_image'
assert default_config.sandbox.user_id == 1001
assert default_config.security.security_analyzer == 'semgrep'
@@ -476,7 +474,7 @@ invalid_field = "test"
model = "gpt-4"
[agent]
-memory_enabled = true
+enable_prompt_extensions = true
[sandbox]
invalid_field_in_sandbox = "test"
@@ -553,15 +551,6 @@ def test_workspace_mount_rewrite(default_config, monkeypatch):
assert default_config.workspace_mount_path == '/sandbox/project'
def test_embedding_base_url_default(default_config):
default_config.get_llm_config().base_url = 'https://api.exampleapi.com'
finalize_config(default_config)
assert (
default_config.get_llm_config().embedding_base_url
== 'https://api.exampleapi.com'
)
def test_cache_dir_creation(default_config, tmpdir):
default_config.cache_dir = str(tmpdir.join('test_cache'))
finalize_config(default_config)
@@ -845,7 +834,9 @@ def test_api_keys_repr_str():
# Test AgentConfig
# No attrs in AgentConfig have 'key' or 'token' in their name
-agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
+agent_config = AgentConfig(
+enable_prompt_extensions=True, codeact_enable_browsing=False
+)
for attr_name in AgentConfig.model_fields.keys():
if not attr_name.startswith('__'):
assert (
@@ -927,12 +918,10 @@ max_budget_per_task = 4.0
[llm.gpt3]
model="gpt-3.5-turbo"
api_key="redacted"
embedding_model="openai"
[llm.gpt4o]
model="gpt-4o"
api_key="redacted"
embedding_model="openai"
"""
with open(temp_toml_file, 'w') as f:
@@ -940,7 +929,6 @@ embedding_model="openai"
llm_config = get_llm_config_arg('gpt3', temp_toml_file)
assert llm_config.model == 'gpt-3.5-turbo'
assert llm_config.embedding_model == 'openai'
def test_get_agent_configs(default_config, temp_toml_file):
@@ -950,10 +938,10 @@ max_iterations = 100
max_budget_per_task = 4.0
[agent.CodeActAgent]
-memory_enabled = true
+enable_prompt_extensions = true
[agent.BrowsingAgent]
-memory_max_threads = 10
+codeact_enable_jupyter = false
"""
with open(temp_toml_file, 'w') as f:
@@ -962,9 +950,9 @@ memory_max_threads = 10
load_from_toml(default_config, temp_toml_file)
codeact_config = default_config.get_agent_configs().get('CodeActAgent')
-assert codeact_config.memory_enabled is True
+assert codeact_config.enable_prompt_extensions is True
browsing_config = default_config.get_agent_configs().get('BrowsingAgent')
-assert browsing_config.memory_max_threads == 10
+assert browsing_config.codeact_enable_jupyter is False
def test_get_agent_config_arg(temp_toml_file):
@@ -974,26 +962,24 @@ max_iterations = 100
max_budget_per_task = 4.0
[agent.CodeActAgent]
-memory_enabled = true
+enable_prompt_extensions = false
+codeact_enable_browsing = false
[agent.BrowsingAgent]
-memory_enabled = false
+enable_prompt_extensions = true
-memory_max_threads = 10
+codeact_enable_jupyter = false
"""
with open(temp_toml_file, 'w') as f:
f.write(temp_toml)
agent_config = get_agent_config_arg('CodeActAgent', temp_toml_file)
-assert agent_config.memory_enabled
+assert not agent_config.enable_prompt_extensions
+assert not agent_config.codeact_enable_browsing
agent_config2 = get_agent_config_arg('BrowsingAgent', temp_toml_file)
-assert not agent_config2.memory_enabled
+assert agent_config2.enable_prompt_extensions
-assert agent_config2.memory_max_threads == 10
+assert not agent_config2.codeact_enable_jupyter
def test_agent_config_custom_group_name(temp_toml_file):
@@ -1002,10 +988,10 @@ def test_agent_config_custom_group_name(temp_toml_file):
max_iterations = 99
[agent.group1]
-memory_enabled = true
+enable_prompt_extensions = true
[agent.group2]
-memory_enabled = false
+enable_prompt_extensions = false
"""
with open(temp_toml_file, 'w') as f:
f.write(temp_toml)
@@ -1017,9 +1003,9 @@ memory_enabled = false
# run_infer in evaluation can use `get_agent_config_arg` to load custom
# agent configs with any group name (not just agent name)
agent_config1 = get_agent_config_arg('group1', temp_toml_file)
-assert agent_config1.memory_enabled
+assert agent_config1.enable_prompt_extensions
agent_config2 = get_agent_config_arg('group2', temp_toml_file)
-assert not agent_config2.memory_enabled
+assert not agent_config2.enable_prompt_extensions
def test_agent_config_from_toml_section():
@@ -1028,11 +1014,10 @@ def test_agent_config_from_toml_section():
# Test with base config and custom configs
agent_section = {
-'memory_enabled': True,
-'memory_max_threads': 5,
+'enable_prompt_extensions': True,
-'CustomAgent1': {'memory_enabled': False, 'codeact_enable_browsing': False},
-'CustomAgent2': {'memory_max_threads': 10, 'enable_prompt_extensions': False},
+'codeact_enable_browsing': True,
+'CustomAgent1': {'codeact_enable_browsing': False},
+'CustomAgent2': {'enable_prompt_extensions': False},
'InvalidAgent': {
'invalid_field': 'some_value' # This should be skipped but not affect others
},
@@ -1043,20 +1028,16 @@ def test_agent_config_from_toml_section():
# Verify the base config was correctly parsed
assert 'agent' in result
-assert result['agent'].memory_enabled is True
-assert result['agent'].memory_max_threads == 5
+assert result['agent'].enable_prompt_extensions is True
+assert result['agent'].codeact_enable_browsing is True
# Verify custom configs were correctly parsed and inherit from base
assert 'CustomAgent1' in result
-assert result['CustomAgent1'].memory_enabled is False # Overridden
-assert result['CustomAgent1'].memory_max_threads == 5 # Inherited
+assert result['CustomAgent1'].codeact_enable_browsing is False # Overridden
+assert result['CustomAgent1'].enable_prompt_extensions is True # Inherited
assert 'CustomAgent2' in result
-assert result['CustomAgent2'].memory_enabled is True # Inherited
-assert result['CustomAgent2'].memory_max_threads == 10 # Overridden
+assert result['CustomAgent2'].codeact_enable_browsing is True # Inherited
+assert result['CustomAgent2'].enable_prompt_extensions is False # Overridden
# Verify the invalid config was skipped
@@ -1070,8 +1051,11 @@ def test_agent_config_from_toml_section_with_invalid_base():
# Test with invalid base config but valid custom configs
agent_section = {
'invalid_field': 'some_value', # This should be ignored in base config
-'memory_max_threads': 'not_an_int', # This should cause validation error
-'CustomAgent': {'memory_enabled': True, 'memory_max_threads': 8},
+'codeact_enable_jupyter': 'not_a_bool', # This should cause validation error
+'CustomAgent': {
+'codeact_enable_browsing': False,
+'codeact_enable_jupyter': True,
+},
}
# Parse the section
@@ -1079,10 +1063,10 @@
# Verify a default base config was created despite the invalid fields
assert 'agent' in result
-assert result['agent'].memory_enabled is False # Default value
-assert result['agent'].memory_max_threads == 3 # Default value
+assert result['agent'].codeact_enable_browsing is True # Default value
+assert result['agent'].codeact_enable_jupyter is True # Default value
# Verify custom config was still processed correctly
assert 'CustomAgent' in result
-assert result['CustomAgent'].memory_enabled is True
-assert result['CustomAgent'].memory_max_threads == 8
+assert result['CustomAgent'].codeact_enable_browsing is False
+assert result['CustomAgent'].codeact_enable_jupyter is True

View File

@@ -8,8 +8,9 @@ from openhands.core.config.utils import load_from_toml
def test_extended_config_from_dict():
"""
Test that ExtendedConfig.from_dict successfully creates an instance
"""Test that ExtendedConfig.from_dict successfully creates an instance.
This test verifies that the from_dict method correctly creates an instance
from a dictionary containing arbitrary extra keys.
"""
data = {'foo': 'bar', 'baz': 123, 'flag': True}
@@ -24,9 +25,7 @@ def test_extended_config_from_dict():
def test_extended_config_empty():
"""
Test that an empty ExtendedConfig can be created and accessed.
"""
"""Test that an empty ExtendedConfig can be created and accessed."""
ext_cfg = ExtendedConfig.from_dict({})
assert ext_cfg.root == {}
@@ -36,9 +35,10 @@ def test_extended_config_empty():
def test_extended_config_str_and_repr():
"""
Test that __str__ and __repr__ return the correct string representations
of the ExtendedConfig instance.
"""Test that __str__ and __repr__ return the correct string representations.
This test verifies that the string representations of the ExtendedConfig instance
include the expected key/value pairs.
"""
data = {'alpha': 'test', 'beta': 42}
ext_cfg = ExtendedConfig.from_dict(data)
@@ -54,9 +54,10 @@ def test_extended_config_str_and_repr():
def test_extended_config_getitem_and_getattr():
"""
Test that __getitem__ and __getattr__ can be used to access values
in the ExtendedConfig instance.
"""Test that __getitem__ and __getattr__ can be used to access values.
This test verifies that values in the ExtendedConfig instance can be accessed
both via attribute access and dictionary-style access.
"""
data = {'key1': 'value1', 'key2': 2}
ext_cfg = ExtendedConfig.from_dict(data)
@@ -68,9 +69,7 @@ def test_extended_config_getitem_and_getattr():
def test_extended_config_invalid_key():
"""
Test that accessing a non-existent key via attribute access raises AttributeError.
"""
"""Test that accessing a non-existent key via attribute access raises AttributeError."""
data = {'existing': 'yes'}
ext_cfg = ExtendedConfig.from_dict(data)
@@ -82,9 +81,10 @@ def test_extended_config_invalid_key():
def test_app_config_extended_from_toml(tmp_path: os.PathLike) -> None:
"""
Test that the [extended] section in a TOML file is correctly loaded into
AppConfig.extended and that it accepts arbitrary keys.
"""Test that the [extended] section in a TOML file is correctly loaded.
This test verifies that the [extended] section is loaded into AppConfig.extended
and that it accepts arbitrary keys.
"""
# Create a temporary TOML file with multiple sections including [extended]
config_content = """
@@ -101,7 +101,7 @@ custom2 = 42
llm = "overridden" # even a key like 'llm' is accepted in extended
[agent]
-memory_enabled = true
+enable_prompt_extensions = true
"""
config_file = tmp_path / 'config.toml'
config_file.write_text(config_content)
@@ -118,8 +119,9 @@ memory_enabled = true
def test_app_config_extended_default(tmp_path: os.PathLike) -> None:
"""
Test that if there is no [extended] section in the TOML file,
"""Test default behavior when no [extended] section exists.
This test verifies that if there is no [extended] section in the TOML file,
AppConfig.extended remains its default (empty) ExtendedConfig.
"""
config_content = """
@@ -131,7 +132,7 @@ model = "test-model"
api_key = "toml-api-key"
[agent]
-memory_enabled = true
+enable_prompt_extensions = true
"""
config_file = tmp_path / 'config.toml'
config_file.write_text(config_content)
@@ -144,8 +145,9 @@ memory_enabled = true
def test_app_config_extended_random_keys(tmp_path: os.PathLike) -> None:
"""
Test that the extended section accepts arbitrary keys,
"""Test that the extended section accepts arbitrary keys.
This test verifies that the extended section accepts arbitrary keys,
including ones not defined in any schema.
"""
config_content = """

View File

@@ -25,7 +25,6 @@ workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom1]
@@ -60,28 +59,24 @@ def test_load_from_toml_llm_with_fallback(
generic_llm = default_config.get_llm_config('llm')
assert generic_llm.model == 'base-model'
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
assert generic_llm.embedding_model == 'base-embedding'
assert generic_llm.num_retries == 3
# Verify custom1 LLM falls back 'num_retries' from base
custom1 = default_config.get_llm_config('custom1')
assert custom1.model == 'custom-model-1'
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
assert custom1.embedding_model == 'base-embedding'
assert custom1.num_retries == 3 # from [llm]
# Verify custom2 LLM overrides 'num_retries'
custom2 = default_config.get_llm_config('custom2')
assert custom2.model == 'custom-model-2'
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
assert custom2.embedding_model == 'base-embedding'
assert custom2.num_retries == 5 # overridden value
# Verify custom3 LLM inherits all attributes except 'model' and 'api_key'
custom3 = default_config.get_llm_config('custom3')
assert custom3.model == 'custom-model-3'
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
assert custom3.embedding_model == 'base-embedding'
assert custom3.num_retries == 3 # from [llm]
@@ -96,13 +91,11 @@ workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom_full]
model = "full-custom-model"
api_key = "full-custom-api-key"
embedding_model = "full-custom-embedding"
num_retries = 10
"""
toml_file = tmp_path / 'full_override_llm.toml'
@@ -114,14 +107,12 @@ num_retries = 10
generic_llm = default_config.get_llm_config('llm')
assert generic_llm.model == 'base-model'
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
assert generic_llm.embedding_model == 'base-embedding'
assert generic_llm.num_retries == 3
# Verify custom_full LLM overrides all attributes
custom_full = default_config.get_llm_config('custom_full')
assert custom_full.model == 'full-custom-model'
assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
assert custom_full.embedding_model == 'full-custom-embedding'
assert custom_full.num_retries == 10 # overridden value
@@ -137,14 +128,12 @@ def test_load_from_toml_llm_custom_partial_override(
custom1 = default_config.get_llm_config('custom1')
assert custom1.model == 'custom-model-1'
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
assert custom1.embedding_model == 'base-embedding'
assert custom1.num_retries == 3 # from [llm]
# Verify custom2 LLM overrides 'model', 'api_key', and 'num_retries'
custom2 = default_config.get_llm_config('custom2')
assert custom2.model == 'custom-model-2'
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
assert custom2.embedding_model == 'base-embedding'
assert custom2.num_retries == 5 # Overridden value
@@ -156,11 +145,10 @@ def test_load_from_toml_llm_custom_no_override(
"""
load_from_toml(default_config, generic_llm_toml)
-# Verify custom3 LLM inherits 'embedding_model' and 'num_retries' from generic
+# Verify custom3 LLM inherits 'num_retries' from generic
custom3 = default_config.get_llm_config('custom3')
assert custom3.model == 'custom-model-3'
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
assert custom3.embedding_model == 'base-embedding'
assert custom3.num_retries == 3 # from [llm]
@@ -187,7 +175,6 @@ api_key = "custom-only-api-key"
custom_only = default_config.get_llm_config('custom_only')
assert custom_only.model == 'custom-only-model'
assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
assert custom_only.embedding_model == 'local' # default value
assert custom_only.num_retries == 4 # default value
@@ -225,4 +212,3 @@ unknown_attr = "should_not_exist"
assert custom_invalid.model == 'base-model'
assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
assert custom_invalid.num_retries == 3 # default value
assert custom_invalid.embedding_model == 'local' # default value

View File

@@ -103,5 +103,3 @@ def test_draft_editor_fallback(config_toml_with_draft_editor):
draft_editor_config = config.get_llm_config('draft_editor')
# num_retries is an example default from llm section
assert draft_editor_config.num_retries == 7
# embedding_model is defaulted in the LLMConfig class
assert draft_editor_config.embedding_model == 'local'

View File

@@ -1,251 +0,0 @@
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.memory.long_term_memory import LongTermMemory
from openhands.storage.files import FileStore
@pytest.fixture
def mock_llm_config() -> LLMConfig:
config = MagicMock(spec=LLMConfig)
config.embedding_model = 'test_embedding_model'
config.api_key = 'test_api_key'
config.api_version = 'v1'
return config
@pytest.fixture
def mock_agent_config() -> AgentConfig:
config = AgentConfig(
memory_enabled=True,
memory_max_threads=4,
llm_config='test_llm_config',
)
return config
@pytest.fixture
def mock_file_store() -> FileStore:
store = MagicMock(spec=FileStore)
store.sid = 'test_session'
return store
@pytest.fixture
def mock_event_stream(mock_file_store: FileStore) -> EventStream:
with patch('openhands.events.stream.EventStream') as MockEventStream:
instance = MockEventStream.return_value
instance.sid = 'test_session'
instance.get_events = MagicMock()
return instance
@pytest.fixture
def long_term_memory(
mock_llm_config: LLMConfig,
mock_agent_config: AgentConfig,
mock_event_stream: EventStream,
) -> LongTermMemory:
mod = LongTermMemory.__module__
with patch(f'{mod}.chromadb.PersistentClient') as mock_chroma_client:
mock_collection = MagicMock()
mock_chroma_client.return_value.get_or_create_collection.return_value = (
mock_collection
)
with (
patch(f'{mod}.ChromaVectorStore', MagicMock()),
patch(f'{mod}.EmbeddingsLoader', MagicMock()),
patch(f'{mod}.VectorStoreIndex', MagicMock()),
):
memory = LongTermMemory(
llm_config=mock_llm_config,
agent_config=mock_agent_config,
event_stream=mock_event_stream,
)
memory.collection = mock_collection
return memory
def _create_action_event(action: str) -> Event:
"""Helper function to create an action event."""
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.AGENT
event.action = action
return event
def _create_observation_event(observation: str) -> Event:
"""Helper function to create an observation event."""
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.ENVIRONMENT
event.observation = observation
return event
def test_add_event_with_action(long_term_memory: LongTermMemory):
event = _create_action_event('test_action')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'action'
assert kwargs['document'].extra_info['id'] == 'test_action'
def test_add_event_with_observation(long_term_memory: LongTermMemory):
event = _create_observation_event('test_observation')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'observation'
assert kwargs['document'].extra_info['id'] == 'test_observation'
def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
# Creating an event with additional unexpected attributes
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.AGENT
event.action = 'test_action'
event.unexpected_key = 'value'
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'action'
assert kwargs['document'].extra_info['id'] == 'test_action'
def test_events_to_docs_no_events(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
mock_event_stream.get_events.side_effect = FileNotFoundError
# convert events to documents
documents = long_term_memory._events_to_docs()
# since get_events raises, documents should be empty
assert len(documents) == 0
# thought_idx remains unchanged
assert long_term_memory.thought_idx == 0
def test_load_events_into_index_with_invalid_json(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
"""Test loading events with malformed event data."""
# Simulate an event that causes event_to_memory to raise a JSONDecodeError
with patch(
'openhands.memory.long_term_memory.event_to_memory',
side_effect=json.JSONDecodeError('Expecting value', '', 0),
):
event = _create_action_event('invalid_action')
mock_event_stream.get_events.return_value = [event]
# convert events to documents
documents = long_term_memory._events_to_docs()
# since event_to_memory raises, documents should be empty
assert len(documents) == 0
# thought_idx remains unchanged
assert long_term_memory.thought_idx == 0
def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
event = _create_action_event('test_action')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
long_term_memory._add_document.assert_called()
_, kwargs = long_term_memory._add_document.call_args
assert 'document' in kwargs
assert (
kwargs['document'].text
== '{"source": "agent", "action": "test_action", "args": {}}'
)
def test_search_returns_correct_results(long_term_memory: LongTermMemory):
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [
MagicMock(get_text=MagicMock(return_value='result1')),
MagicMock(get_text=MagicMock(return_value='result2')),
]
with patch(
'openhands.memory.long_term_memory.VectorIndexRetriever',
return_value=mock_retriever,
):
results = long_term_memory.search(query='test query', k=2)
assert results == ['result1', 'result2']
mock_retriever.retrieve.assert_called_once_with('test query')
def test_search_with_no_results(long_term_memory: LongTermMemory):
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = []
with patch(
'openhands.memory.long_term_memory.VectorIndexRetriever',
return_value=mock_retriever,
):
results = long_term_memory.search(query='no results', k=5)
assert results == []
mock_retriever.retrieve.assert_called_once_with('no results')
def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
event1 = _create_action_event('action1')
event2 = _create_observation_event('observation1')
long_term_memory.add_event(event1)
long_term_memory.add_event(event2)
assert long_term_memory.thought_idx == 2
def test_load_events_batch_insert(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
event1 = _create_action_event('action1')
event2 = _create_observation_event('observation1')
event3 = _create_action_event('action2')
mock_event_stream.get_events.return_value = [event1, event2, event3]
# Mock insert_batch_docs
with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
# convert events to documents
documents = long_term_memory._events_to_docs()
# Mock the insert_batch_docs to simulate document insertion
mock_run_docs.return_value = []
# Call insert_batch_docs with the documents
mock_run_docs(
index=long_term_memory.index,
documents=documents,
num_workers=long_term_memory.memory_max_threads,
)
# Assert that insert_batch_docs was called with the correct arguments
mock_run_docs.assert_called_once_with(
index=long_term_memory.index,
documents=documents,
num_workers=long_term_memory.memory_max_threads,
)
# Check if thought_idx was incremented correctly
assert long_term_memory.thought_idx == 3

View File

@@ -30,8 +30,8 @@ def event_stream(temp_dir):
@pytest.fixture
def agent_configs():
return {
-'CoderAgent': AgentConfig(memory_enabled=True),
-'BrowsingAgent': AgentConfig(memory_enabled=True),
+'CoderAgent': AgentConfig(enable_prompt_extensions=True),
+'BrowsingAgent': AgentConfig(enable_prompt_extensions=True),
}
@@ -91,8 +91,9 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict
def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: dict):
"""When there's no codebase_summary available, there shouldn't be any prompt
about 'code summary'
"""When there's no codebase_summary available, there shouldn't be any prompt about 'code summary'.
This test verifies that the prompt doesn't include code summary text when no summary is provided.
"""
mock_llm = MagicMock()
content = json.dumps({'action': 'finish', 'args': {}})