Refactor embeddings (#4219)
parent 40d2935911
commit 9d0e6a24bc
.github/workflows/py-unit-tests.yml (vendored, 4 lines changed)
@@ -93,7 +93,7 @@ jobs:
        id: buildx
        uses: docker/setup-buildx-action@v3
      - name: Run Tests
        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit
        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit --ignore=tests/unit/test_memory.py
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4
        env:
@@ -125,7 +125,7 @@ jobs:
      - name: Build Environment
        run: make build
      - name: Run Tests
        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit
        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4
        env:
config.template.toml
@@ -185,7 +185,7 @@ model = "gpt-4o-mini"
#memory_enabled = false

# Memory maximum threads
#memory_max_threads = 2
#memory_max_threads = 3

# LLM config group to use
#llm_config = 'your-llm-config-group'
openhands/core/config/agent_config.py
@@ -16,7 +16,7 @@ class AgentConfig:

    micro_agent_name: str | None = None
    memory_enabled: bool = False
    memory_max_threads: int = 2
    memory_max_threads: int = 3
    llm_config: str | None = None

    def defaults_to_dict(self) -> dict:
openhands/events/serialization/event.py
@@ -96,7 +96,7 @@ def event_to_memory(event: 'Event', max_message_chars: int) -> dict:

def truncate_content(content: str, max_chars: int) -> str:
    """Truncate the middle of the observation content if it is too long."""
    if len(content) <= max_chars:
    if len(content) <= max_chars or max_chars == -1:
        return content

    # truncate the middle and include a message to the LLM about it
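For context, this change lets callers pass max_chars == -1 to skip truncation entirely (memory.py uses event_to_memory(event, -1) below). A minimal sketch of the middle-truncation idea; the function name, notice string, and split are illustrative, not the project's actual implementation:

def truncate_middle(content: str, max_chars: int) -> str:
    """Keep the head and tail of content, dropping the middle beyond max_chars."""
    # -1 acts as a sentinel meaning "do not truncate"
    if max_chars == -1 or len(content) <= max_chars:
        return content
    notice = '\n... (content truncated for length) ...\n'  # hypothetical wording
    half = max(0, (max_chars - len(notice)) // 2)
    return content[:half] + notice + content[len(content) - half:]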
openhands/memory/memory.py
@@ -1,189 +1,187 @@
import threading
import json

from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_memory
from openhands.events.stream import EventStream
from openhands.utils.embeddings import (
    LLAMA_INDEX_AVAILABLE,
    EmbeddingsLoader,
    check_llama_index,
)

from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.utils.tenacity_stop import stop_if_should_exit

try:
    import chromadb
    import llama_index.embeddings.openai.base as llama_openai
    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.retrievers import VectorIndexRetriever
    from llama_index.vector_stores.chroma import ChromaVectorStore

    LLAMA_INDEX_AVAILABLE = True
except ImportError:
    LLAMA_INDEX_AVAILABLE = False

# Conditional imports based on llama_index availability
if LLAMA_INDEX_AVAILABLE:
    # TODO: this could be made configurable
    num_retries: int = 10
    retry_min_wait: int = 3
    retry_max_wait: int = 300

    # llama-index includes a retry decorator around openai.get_embeddings() function
    # it is initialized with hard-coded values and errors
    # this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
    # this block attempts to banish it and replace it with our decorator, to allow users to set their own limits

    if hasattr(llama_openai.get_embeddings, '__wrapped__'):
        original_get_embeddings = llama_openai.get_embeddings.__wrapped__
    else:
        logger.warning('Cannot set custom retry limits.')
        num_retries = 1
        original_get_embeddings = llama_openai.get_embeddings

    def attempt_on_error(retry_state):
        logger.error(
            f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
            exc_info=False,
        )
        return None

    @retry(
        reraise=True,
        stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
        wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, InternalServerError)
        ),
        after=attempt_on_error,
    import chromadb
    from llama_index.core import Document
    from llama_index.core.indices.vector_store.base import VectorStoreIndex
    from llama_index.core.indices.vector_store.retrievers.retriever import (
        VectorIndexRetriever,
    )
    def wrapper_get_embeddings(*args, **kwargs):
        return original_get_embeddings(*args, **kwargs)

    llama_openai.get_embeddings = wrapper_get_embeddings

class EmbeddingsLoader:
    """Loader for embedding model initialization."""

    @staticmethod
    def get_embedding_model(strategy: str, llm_config: LLMConfig):
        supported_ollama_embed_models = [
            'llama2',
            'mxbai-embed-large',
            'nomic-embed-text',
            'all-minilm',
            'stable-code',
            'bge-m3',
            'bge-large',
            'paraphrase-multilingual',
            'snowflake-arctic-embed',
        ]
        if strategy in supported_ollama_embed_models:
            from llama_index.embeddings.ollama import OllamaEmbedding

            return OllamaEmbedding(
                model_name=strategy,
                base_url=llm_config.embedding_base_url,
                ollama_additional_kwargs={'mirostat': 0},
            )
        elif strategy == 'openai':
            from llama_index.embeddings.openai import OpenAIEmbedding

            return OpenAIEmbedding(
                model='text-embedding-ada-002',
                api_key=llm_config.api_key,
            )
        elif strategy == 'azureopenai':
            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

            return AzureOpenAIEmbedding(
                model='text-embedding-ada-002',
                deployment_name=llm_config.embedding_deployment_name,
                api_key=llm_config.api_key,
                azure_endpoint=llm_config.base_url,
                api_version=llm_config.api_version,
            )
        elif (strategy is not None) and (strategy.lower() == 'none'):
            # TODO: this works but is not elegant enough. The incentive is when
            # an agent using embeddings is not used, there is no reason we need to
            # initialize an embedding model
            return None
        else:
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
    from llama_index.core.schema import TextNode
    from llama_index.vector_stores.chroma import ChromaVectorStore


class LongTermMemory:
    """Handles storing information for the agent to access later, using chromadb."""

    def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
        """Initialize the chromadb and set up ChromaVectorStore for later use."""
        if not LLAMA_INDEX_AVAILABLE:
            raise ImportError(
                'llama_index and its dependencies are not installed. '
                'To use LongTermMemory, please run: poetry install --with llama-index'
            )
    event_stream: EventStream

        db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
    def __init__(
        self,
        llm_config: LLMConfig,
        agent_config: AgentConfig,
        event_stream: EventStream,
    ):
        """Initialize the chromadb and set up ChromaVectorStore for later use."""

        check_llama_index()

        # initialize the chromadb client
        db = chromadb.PersistentClient(
            path=f'./cache/sessions/{event_stream.sid}/memory',
            # FIXME anonymized_telemetry=False,
        )
        self.collection = db.get_or_create_collection(name='memories')
        vector_store = ChromaVectorStore(chroma_collection=self.collection)

        # embedding model
        embedding_strategy = llm_config.embedding_model
        embed_model = EmbeddingsLoader.get_embedding_model(
        self.embed_model = EmbeddingsLoader.get_embedding_model(
            embedding_strategy, llm_config
        )
        self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
        self.sema = threading.Semaphore(value=memory_max_threads)
        self.thought_idx = 0
        self._add_threads: list[threading.Thread] = []

    def add_event(self, event: dict):
        # instantiate the index
        self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
        self.thought_idx = 0

        # initialize the event stream
        self.event_stream = event_stream

        # max of threads to run the pipeline
        self.memory_max_threads = agent_config.memory_max_threads

    def add_event(self, event: Event):
        """Adds a new event to the long term memory with a unique id.

        Parameters:
        - event (dict): The new event to be added to memory
        - event: The new event to be added to memory
        """
        id = ''
        t = ''
        if 'action' in event:
            t = 'action'
            id = event['action']
        elif 'observation' in event:
            t = 'observation'
            id = event['observation']
        try:
            # convert the event to a memory-friendly format, and don't truncate
            event_data = event_to_memory(event, -1)
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning(f'Failed to process event: {e}')
            return

        # determine the event type and ID
        event_type = ''
        event_id = ''
        if 'action' in event_data:
            event_type = 'action'
            event_id = event_data['action']
        elif 'observation' in event_data:
            event_type = 'observation'
            event_id = event_data['observation']

        # create a Document instance for the event
        doc = Document(
            text=json.dumps(event),
            text=json.dumps(event_data),
            doc_id=str(self.thought_idx),
            extra_info={
                'type': t,
                'id': id,
                'type': event_type,
                'id': event_id,
                'idx': self.thought_idx,
            },
        )
        self.thought_idx += 1
        logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
        thread = threading.Thread(target=self._add_doc, args=(doc,))
        self._add_threads.append(thread)
        thread.start()  # We add the doc concurrently so we don't have to wait ~500ms for the insert
        logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
        self._add_document(document=doc)

    def _add_doc(self, doc):
        with self.sema:
            self.index.insert(doc)
    def _add_document(self, document: 'Document'):
        """Inserts a single document into the index."""
        self.index.insert_nodes([self._create_node(document)])

    def search(self, query: str, k: int = 10):
        """Searches through the current memory using VectorIndexRetriever
    def _create_node(self, document: 'Document') -> 'TextNode':
        """Create a TextNode from a Document instance."""
        return TextNode(
            text=document.text,
            doc_id=document.doc_id,
            extra_info=document.extra_info,
        )

    def search(self, query: str, k: int = 10) -> list[str]:
        """Searches through the current memory using VectorIndexRetriever.

        Parameters:
        - query (str): A query to match search results to
        - k (int): Number of top results to return

        Returns:
        - list[str]: list of top k results found in current memory
        - list[str]: List of top k results found in current memory
        """
        retriever = VectorIndexRetriever(
            index=self.index,
            similarity_top_k=k,
        )
        results = retriever.retrieve(query)

        for result in results:
            logger.debug(
                f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
            )

        return [r.get_text() for r in results]

    def _events_to_docs(self) -> list['Document']:
        """Convert all events from the EventStream to documents for batch insert into the index."""
        try:
            events = self.event_stream.get_events()
        except Exception as e:
            logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
            return []

        documents: list[Document] = []

        for event in events:
            try:
                # convert the event to a memory-friendly format, and don't truncate
                event_data = event_to_memory(event, -1)

                # determine the event type and ID
                event_type = ''
                event_id = ''
                if 'action' in event_data:
                    event_type = 'action'
                    event_id = event_data['action']
                elif 'observation' in event_data:
                    event_type = 'observation'
                    event_id = event_data['observation']

                # create a Document instance for the event
                doc = Document(
                    text=json.dumps(event_data),
                    doc_id=str(self.thought_idx),
                    extra_info={
                        'type': event_type,
                        'id': event_id,
                        'idx': self.thought_idx,
                    },
                )
                documents.append(doc)
                self.thought_idx += 1
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f'Failed to process event: {e}')
                continue

        if documents:
            logger.debug(f'Batch inserting {len(documents)} documents into the index.')
        else:
            logger.debug('No valid documents found to insert into the index.')

        return documents

    def create_nodes(self, documents: list['Document']) -> list['TextNode']:
        """Create nodes from a list of documents."""
        return [self._create_node(doc) for doc in documents]
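For orientation before the new utility module, the event-to-document bookkeeping shared by add_event and _events_to_docs, reduced to a self-contained sketch (plain dicts stand in for the event_to_memory output and for the llama-index Document; the helper names are illustrative):

import json


def classify_event(event_data: dict) -> tuple[str, str]:
    """Mirror the type/id extraction used in add_event and _events_to_docs."""
    if 'action' in event_data:
        return 'action', event_data['action']
    if 'observation' in event_data:
        return 'observation', event_data['observation']
    return '', ''


def to_document_fields(event_data: dict, thought_idx: int) -> dict:
    """Build the fields that LongTermMemory passes to the llama-index Document."""
    event_type, event_id = classify_event(event_data)
    return {
        'text': json.dumps(event_data),
        'doc_id': str(thought_idx),
        'extra_info': {'type': event_type, 'id': event_id, 'idx': thought_idx},
    }


print(to_document_fields({'source': 'agent', 'action': 'run', 'args': {}}, 0))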
openhands/utils/embeddings.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import importlib.util
import os

from joblib import Parallel, delayed

from openhands.core.config import LLMConfig

try:
    # check if those we need later are available using importlib
    if importlib.util.find_spec('chromadb') is None:
        raise ImportError(
            'chromadb is not available. Please install it using poetry install --with llama-index'
        )

    if (
        importlib.util.find_spec(
            'llama_index.core.indices.vector_store.retrievers.retriever'
        )
        is None
        or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
        is None
    ):
        raise ImportError(
            'llama_index is not available. Please install it using poetry install --with llama-index'
        )

    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.base.embeddings.base import BaseEmbedding
    from llama_index.core.ingestion import IngestionPipeline
    from llama_index.core.schema import TextNode

    LLAMA_INDEX_AVAILABLE = True

except ImportError:
    LLAMA_INDEX_AVAILABLE = False

# Define supported embedding models
SUPPORTED_OLLAMA_EMBED_MODELS = [
    'llama2',
    'mxbai-embed-large',
    'nomic-embed-text',
    'all-minilm',
    'stable-code',
    'bge-m3',
    'bge-large',
    'paraphrase-multilingual',
    'snowflake-arctic-embed',
]


def check_llama_index():
    """Utility function to check the availability of llama_index.

    Raises:
        ImportError: If llama_index is not available.
    """
    if not LLAMA_INDEX_AVAILABLE:
        raise ImportError(
            'llama_index and its dependencies are not installed. '
            'To use memory features, please run: poetry install --with llama-index.'
        )


class EmbeddingsLoader:
    """Loader for embedding model initialization."""

    @staticmethod
    def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
        """Initialize and return the appropriate embedding model based on the strategy.

        Parameters:
        - strategy: The embedding strategy to use.
        - llm_config: Configuration for the LLM.

        Returns:
        - An instance of the selected embedding model or None.
        """

        if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
            from llama_index.embeddings.ollama import OllamaEmbedding

            return OllamaEmbedding(
                model_name=strategy,
                base_url=llm_config.embedding_base_url,
                ollama_additional_kwargs={'mirostat': 0},
            )
        elif strategy == 'openai':
            from llama_index.embeddings.openai import OpenAIEmbedding

            return OpenAIEmbedding(
                model='text-embedding-ada-002',
                api_key=llm_config.api_key,
            )
        elif strategy == 'azureopenai':
            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

            return AzureOpenAIEmbedding(
                model='text-embedding-ada-002',
                deployment_name=llm_config.embedding_deployment_name,
                api_key=llm_config.api_key,
                azure_endpoint=llm_config.base_url,
                api_version=llm_config.api_version,
            )
        elif (strategy is not None) and (strategy.lower() == 'none'):
            # TODO: this works but is not elegant enough. The incentive is when
            # an agent using embeddings is not used, there is no reason we need to
            # initialize an embedding model
            return None
        else:
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding

            # initialize the local embedding model
            local_embed_model = HuggingFaceEmbedding(
                model_name='BAAI/bge-small-en-v1.5'
            )

            # for local embeddings, we need torch
            import torch

            # choose the best device
            # first determine what is available: CUDA, MPS, or CPU
            if torch.cuda.is_available():
                device = 'cuda'
            elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
                device = 'mps'
            else:
                device = 'cpu'
                os.environ['CUDA_VISIBLE_DEVICES'] = ''
                os.environ['PYTORCH_FORCE_CPU'] = (
                    '1'  # try to force CPU to avoid errors
                )

                # override CUDA availability
                torch.cuda.is_available = lambda: False

            # disable MPS to avoid errors
            if device != 'mps' and hasattr(torch.backends, 'mps'):
                torch.backends.mps.is_available = lambda: False
                torch.backends.mps.is_built = False

            # the device being used
            print(f'Using device for embeddings: {device}')

            return local_embed_model


# --------------------------------------------------------------------------
# Utility functions to run pipelines, split out for profiling
# --------------------------------------------------------------------------
def run_pipeline(
    embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
) -> list['TextNode']:
    """Run a pipeline embedding documents."""

    # set up a pipeline with the transformations to make
    pipeline = IngestionPipeline(
        transformations=[
            embed_model,
        ],
    )

    # run the pipeline with num_workers
    nodes = pipeline.run(
        documents=documents, show_progress=True, num_workers=num_workers
    )
    return nodes


def insert_batch_docs(
    index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
) -> list['TextNode']:
    """Run the document indexing in parallel."""
    results = Parallel(n_jobs=num_workers, backend='threading')(
        delayed(index.insert)(doc) for doc in documents
    )
    return results
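insert_batch_docs above parallelizes index.insert calls with joblib's threading backend. A self-contained sketch of that pattern, using a stand-in index class instead of a real VectorStoreIndex (FakeIndex and insert_batch are illustrative names, not part of the project):

from joblib import Parallel, delayed


class FakeIndex:
    """Stand-in for VectorStoreIndex: records the documents it receives."""

    def __init__(self) -> None:
        self.docs: list[str] = []

    def insert(self, doc: str) -> str:
        self.docs.append(doc)
        return doc


def insert_batch(index: FakeIndex, documents: list[str], num_workers: int) -> list[str]:
    # same shape as insert_batch_docs: one insert call per document, run on threads
    return Parallel(n_jobs=num_workers, backend='threading')(
        delayed(index.insert)(doc) for doc in documents
    )


if __name__ == '__main__':
    idx = FakeIndex()
    results = insert_batch(idx, [f'doc-{i}' for i in range(8)], num_workers=3)
    print(len(results), 'documents inserted')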
tests/unit/test_memory.py (new file, 246 lines)
@@ -0,0 +1,246 @@
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest

from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.memory.memory import LongTermMemory
from openhands.storage.files import FileStore


@pytest.fixture
def mock_llm_config() -> LLMConfig:
    config = MagicMock(spec=LLMConfig)
    config.embedding_model = 'test_embedding_model'
    config.api_key = 'test_api_key'
    config.api_version = 'v1'
    return config


@pytest.fixture
def mock_agent_config() -> AgentConfig:
    config = AgentConfig(
        micro_agent_name='test_micro_agent',
        memory_enabled=True,
        memory_max_threads=4,
        llm_config='test_llm_config',
    )
    return config


@pytest.fixture
def mock_file_store() -> FileStore:
    store = MagicMock(spec=FileStore)
    store.sid = 'test_session'
    return store


@pytest.fixture
def mock_event_stream(mock_file_store: FileStore) -> EventStream:
    with patch('openhands.events.stream.EventStream') as MockEventStream:
        instance = MockEventStream.return_value
        instance.sid = 'test_session'
        instance.get_events = MagicMock()
        return instance


@pytest.fixture
def long_term_memory(
    mock_llm_config: LLMConfig,
    mock_agent_config: AgentConfig,
    mock_event_stream: EventStream,
) -> LongTermMemory:
    with patch(
        'openhands.memory.memory.chromadb.PersistentClient'
    ) as mock_chroma_client:
        mock_collection = MagicMock()
        mock_chroma_client.return_value.get_or_create_collection.return_value = (
            mock_collection
        )
        memory = LongTermMemory(
            llm_config=mock_llm_config,
            agent_config=mock_agent_config,
            event_stream=mock_event_stream,
        )
        memory.collection = mock_collection
        return memory


def _create_action_event(action: str) -> Event:
    """Helper function to create an action event."""
    event = Event()
    event._id = -1
    event._timestamp = datetime.now(timezone.utc).isoformat()
    event._source = EventSource.AGENT
    event.action = action
    return event


def _create_observation_event(observation: str) -> Event:
    """Helper function to create an observation event."""
    event = Event()
    event._id = -1
    event._timestamp = datetime.now(timezone.utc).isoformat()
    event._source = EventSource.USER
    event.observation = observation
    return event


def test_add_event_with_action(long_term_memory: LongTermMemory):
    event = _create_action_event('test_action')
    long_term_memory._add_document = MagicMock()
    long_term_memory.add_event(event)
    assert long_term_memory.thought_idx == 1
    long_term_memory._add_document.assert_called_once()
    _, kwargs = long_term_memory._add_document.call_args
    assert kwargs['document'].extra_info['type'] == 'action'
    assert kwargs['document'].extra_info['id'] == 'test_action'


def test_add_event_with_observation(long_term_memory: LongTermMemory):
    event = _create_observation_event('test_observation')
    long_term_memory._add_document = MagicMock()
    long_term_memory.add_event(event)
    assert long_term_memory.thought_idx == 1
    long_term_memory._add_document.assert_called_once()
    _, kwargs = long_term_memory._add_document.call_args
    assert kwargs['document'].extra_info['type'] == 'observation'
    assert kwargs['document'].extra_info['id'] == 'test_observation'


def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
    # Creating an event with additional unexpected attributes
    event = Event()
    event._id = -1
    event._timestamp = datetime.now(timezone.utc).isoformat()
    event._source = EventSource.AGENT
    event.action = 'test_action'
    event.unexpected_key = 'value'

    long_term_memory._add_document = MagicMock()
    long_term_memory.add_event(event)
    assert long_term_memory.thought_idx == 1
    long_term_memory._add_document.assert_called_once()
    _, kwargs = long_term_memory._add_document.call_args
    assert kwargs['document'].extra_info['type'] == 'action'
    assert kwargs['document'].extra_info['id'] == 'test_action'


def test_events_to_docs_no_events(
    long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
    mock_event_stream.get_events.side_effect = FileNotFoundError

    # convert events to documents
    documents = long_term_memory._events_to_docs()

    # since get_events raises, documents should be empty
    assert len(documents) == 0

    # thought_idx remains unchanged
    assert long_term_memory.thought_idx == 0


def test_load_events_into_index_with_invalid_json(
    long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
    """Test loading events with malformed event data."""
    # Simulate an event that causes event_to_memory to raise a JSONDecodeError
    with patch(
        'openhands.memory.memory.event_to_memory',
        side_effect=json.JSONDecodeError('Expecting value', '', 0),
    ):
        event = _create_action_event('invalid_action')
        mock_event_stream.get_events.return_value = [event]

        # convert events to documents
        documents = long_term_memory._events_to_docs()

        # since event_to_memory raises, documents should be empty
        assert len(documents) == 0

        # thought_idx remains unchanged
        assert long_term_memory.thought_idx == 0


def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
    event = _create_action_event('test_action')
    long_term_memory._add_document = MagicMock()
    long_term_memory.add_event(event)
    long_term_memory._add_document.assert_called()
    _, kwargs = long_term_memory._add_document.call_args
    assert 'document' in kwargs
    assert (
        kwargs['document'].text
        == '{"source": "agent", "action": "test_action", "args": {}}'
    )


def test_search_returns_correct_results(long_term_memory: LongTermMemory):
    mock_retriever = MagicMock()
    mock_retriever.retrieve.return_value = [
        MagicMock(get_text=MagicMock(return_value='result1')),
        MagicMock(get_text=MagicMock(return_value='result2')),
    ]
    with patch(
        'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
    ):
        results = long_term_memory.search(query='test query', k=2)
        assert results == ['result1', 'result2']
        mock_retriever.retrieve.assert_called_once_with('test query')


def test_search_with_no_results(long_term_memory: LongTermMemory):
    mock_retriever = MagicMock()
    mock_retriever.retrieve.return_value = []
    with patch(
        'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
    ):
        results = long_term_memory.search(query='no results', k=5)
        assert results == []
        mock_retriever.retrieve.assert_called_once_with('no results')


def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
    event1 = _create_action_event('action1')
    event2 = _create_observation_event('observation1')
    long_term_memory.add_event(event1)
    long_term_memory.add_event(event2)
    assert long_term_memory.thought_idx == 2


def test_load_events_batch_insert(
    long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
    event1 = _create_action_event('action1')
    event2 = _create_observation_event('observation1')
    event3 = _create_action_event('action2')
    mock_event_stream.get_events.return_value = [event1, event2, event3]

    # Mock insert_batch_docs
    with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
        # convert events to documents
        documents = long_term_memory._events_to_docs()

        # Mock the insert_batch_docs to simulate document insertion
        mock_run_docs.return_value = []

        # Call insert_batch_docs with the documents
        mock_run_docs(
            index=long_term_memory.index,
            documents=documents,
            num_workers=long_term_memory.memory_max_threads,
        )

        # Assert that insert_batch_docs was called with the correct arguments
        mock_run_docs.assert_called_once_with(
            index=long_term_memory.index,
            documents=documents,
            num_workers=long_term_memory.memory_max_threads,
        )

    # Check if thought_idx was incremented correctly
    assert long_term_memory.thought_idx == 3