Refactor embeddings (#4219)

Engel Nyst authored 2024-10-05 20:59:08 +02:00, committed by GitHub
parent 40d2935911
commit 9d0e6a24bc
7 changed files with 571 additions and 151 deletions


@@ -93,7 +93,7 @@ jobs:
id: buildx
uses: docker/setup-buildx-action@v3
- name: Run Tests
-run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit
+run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit --ignore=tests/unit/test_memory.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
env:
@@ -125,7 +125,7 @@ jobs:
- name: Build Environment
run: make build
- name: Run Tests
-run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit
+run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
env:


@@ -185,7 +185,7 @@ model = "gpt-4o-mini"
#memory_enabled = false
# Memory maximum threads
-#memory_max_threads = 2
+#memory_max_threads = 3
# LLM config group to use
#llm_config = 'your-llm-config-group'


@@ -16,7 +16,7 @@ class AgentConfig:
micro_agent_name: str | None = None
memory_enabled: bool = False
-memory_max_threads: int = 2
+memory_max_threads: int = 3
llm_config: str | None = None
def defaults_to_dict(self) -> dict:
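
Note: the default number of memory worker threads moves from 2 to 3 both in the config template above and in this dataclass. A minimal sketch of how the option surfaces in code, assuming only the dataclass fields shown in this hunk (the constructor call below is illustrative, not taken from the commit):

from openhands.core.config import AgentConfig

# Illustrative only: enable long-term memory and keep the new default of
# three worker threads for embedding inserts.
agent_config = AgentConfig(memory_enabled=True, memory_max_threads=3)

# defaults_to_dict() is the method shown in the context line above.
print(agent_config.defaults_to_dict())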


@@ -96,7 +96,7 @@ def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
def truncate_content(content: str, max_chars: int) -> str:
"""Truncate the middle of the observation content if it is too long."""
-if len(content) <= max_chars:
+if len(content) <= max_chars or max_chars == -1:
return content
# truncate the middle and include a message to the LLM about it
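
In other words, -1 now acts as a "do not truncate" sentinel, which the refactored memory module relies on via event_to_memory(event, -1). A small sketch of the intended behavior, assuming truncate_content is importable from this serialization module (the import path is an assumption):

from openhands.events.serialization.event import truncate_content  # assumed import path

long_text = 'x' * 100_000

shortened = truncate_content(long_text, 1_000)   # middle is cut and a notice for the LLM is inserted
unchanged = truncate_content(long_text, -1)      # -1 disables truncation entirely

assert unchanged == long_text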


@@ -1,189 +1,187 @@
import threading
import json
from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_memory
from openhands.events.stream import EventStream
from openhands.utils.embeddings import (
LLAMA_INDEX_AVAILABLE,
EmbeddingsLoader,
check_llama_index,
)
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.utils.tenacity_stop import stop_if_should_exit
try:
import chromadb
import llama_index.embeddings.openai.base as llama_openai
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.vector_stores.chroma import ChromaVectorStore
LLAMA_INDEX_AVAILABLE = True
except ImportError:
LLAMA_INDEX_AVAILABLE = False
# Conditional imports based on llama_index availability
if LLAMA_INDEX_AVAILABLE:
# TODO: this could be made configurable
num_retries: int = 10
retry_min_wait: int = 3
retry_max_wait: int = 300
# llama-index includes a retry decorator around openai.get_embeddings() function
# it is initialized with hard-coded values and errors
# this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
# this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
if hasattr(llama_openai.get_embeddings, '__wrapped__'):
original_get_embeddings = llama_openai.get_embeddings.__wrapped__
else:
logger.warning('Cannot set custom retry limits.')
num_retries = 1
original_get_embeddings = llama_openai.get_embeddings
def attempt_on_error(retry_state):
logger.error(
f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
exc_info=False,
)
return None
@retry(
reraise=True,
stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, InternalServerError)
),
after=attempt_on_error,
import chromadb
from llama_index.core import Document
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
def wrapper_get_embeddings(*args, **kwargs):
return original_get_embeddings(*args, **kwargs)
llama_openai.get_embeddings = wrapper_get_embeddings
class EmbeddingsLoader:
"""Loader for embedding model initialization."""
@staticmethod
def get_embedding_model(strategy: str, llm_config: LLMConfig):
supported_ollama_embed_models = [
'llama2',
'mxbai-embed-large',
'nomic-embed-text',
'all-minilm',
'stable-code',
'bge-m3',
'bge-large',
'paraphrase-multilingual',
'snowflake-arctic-embed',
]
if strategy in supported_ollama_embed_models:
from llama_index.embeddings.ollama import OllamaEmbedding
return OllamaEmbedding(
model_name=strategy,
base_url=llm_config.embedding_base_url,
ollama_additional_kwargs={'mirostat': 0},
)
elif strategy == 'openai':
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
return AzureOpenAIEmbedding(
model='text-embedding-ada-002',
deployment_name=llm_config.embedding_deployment_name,
api_key=llm_config.api_key,
azure_endpoint=llm_config.base_url,
api_version=llm_config.api_version,
)
elif (strategy is not None) and (strategy.lower() == 'none'):
# TODO: this works but is not elegant enough. The incentive is when
# an agent using embeddings is not used, there is no reason we need to
# initialize an embedding model
return None
else:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class LongTermMemory:
"""Handles storing information for the agent to access later, using chromadb."""
def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
if not LLAMA_INDEX_AVAILABLE:
raise ImportError(
'llama_index and its dependencies are not installed. '
'To use LongTermMemory, please run: poetry install --with llama-index'
)
event_stream: EventStream
db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
def __init__(
self,
llm_config: LLMConfig,
agent_config: AgentConfig,
event_stream: EventStream,
):
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
check_llama_index()
# initialize the chromadb client
db = chromadb.PersistentClient(
path=f'./cache/sessions/{event_stream.sid}/memory',
# FIXME anonymized_telemetry=False,
)
self.collection = db.get_or_create_collection(name='memories')
vector_store = ChromaVectorStore(chroma_collection=self.collection)
# embedding model
embedding_strategy = llm_config.embedding_model
embed_model = EmbeddingsLoader.get_embedding_model(
self.embed_model = EmbeddingsLoader.get_embedding_model(
embedding_strategy, llm_config
)
self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
self.sema = threading.Semaphore(value=memory_max_threads)
self.thought_idx = 0
self._add_threads: list[threading.Thread] = []
def add_event(self, event: dict):
# instantiate the index
self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
self.thought_idx = 0
# initialize the event stream
self.event_stream = event_stream
# max of threads to run the pipeline
self.memory_max_threads = agent_config.memory_max_threads
def add_event(self, event: Event):
"""Adds a new event to the long term memory with a unique id.
Parameters:
- event (dict): The new event to be added to memory
- event: The new event to be added to memory
"""
id = ''
t = ''
if 'action' in event:
t = 'action'
id = event['action']
elif 'observation' in event:
t = 'observation'
id = event['observation']
try:
# convert the event to a memory-friendly format, and don't truncate
event_data = event_to_memory(event, -1)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f'Failed to process event: {e}')
return
# determine the event type and ID
event_type = ''
event_id = ''
if 'action' in event_data:
event_type = 'action'
event_id = event_data['action']
elif 'observation' in event_data:
event_type = 'observation'
event_id = event_data['observation']
# create a Document instance for the event
doc = Document(
text=json.dumps(event),
text=json.dumps(event_data),
doc_id=str(self.thought_idx),
extra_info={
'type': t,
'id': id,
'type': event_type,
'id': event_id,
'idx': self.thought_idx,
},
)
self.thought_idx += 1
logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
thread = threading.Thread(target=self._add_doc, args=(doc,))
self._add_threads.append(thread)
thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert
logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
self._add_document(document=doc)
def _add_doc(self, doc):
with self.sema:
self.index.insert(doc)
def _add_document(self, document: 'Document'):
"""Inserts a single document into the index."""
self.index.insert_nodes([self._create_node(document)])
def search(self, query: str, k: int = 10):
"""Searches through the current memory using VectorIndexRetriever
def _create_node(self, document: 'Document') -> 'TextNode':
"""Create a TextNode from a Document instance."""
return TextNode(
text=document.text,
doc_id=document.doc_id,
extra_info=document.extra_info,
)
def search(self, query: str, k: int = 10) -> list[str]:
"""Searches through the current memory using VectorIndexRetriever.
Parameters:
- query (str): A query to match search results to
- k (int): Number of top results to return
Returns:
- list[str]: list of top k results found in current memory
- list[str]: List of top k results found in current memory
"""
retriever = VectorIndexRetriever(
index=self.index,
similarity_top_k=k,
)
results = retriever.retrieve(query)
for result in results:
logger.debug(
f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
)
return [r.get_text() for r in results]
def _events_to_docs(self) -> list['Document']:
"""Convert all events from the EventStream to documents for batch insert into the index."""
try:
events = self.event_stream.get_events()
except Exception as e:
logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
return []
documents: list[Document] = []
for event in events:
try:
# convert the event to a memory-friendly format, and don't truncate
event_data = event_to_memory(event, -1)
# determine the event type and ID
event_type = ''
event_id = ''
if 'action' in event_data:
event_type = 'action'
event_id = event_data['action']
elif 'observation' in event_data:
event_type = 'observation'
event_id = event_data['observation']
# create a Document instance for the event
doc = Document(
text=json.dumps(event_data),
doc_id=str(self.thought_idx),
extra_info={
'type': event_type,
'id': event_id,
'idx': self.thought_idx,
},
)
documents.append(doc)
self.thought_idx += 1
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f'Failed to process event: {e}')
continue
if documents:
logger.debug(f'Batch inserting {len(documents)} documents into the index.')
else:
logger.debug('No valid documents found to insert into the index.')
return documents
def create_nodes(self, documents: list['Document']) -> list['TextNode']:
"""Create nodes from a list of documents."""
return [self._create_node(doc) for doc in documents]
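
Taken together, LongTermMemory is now constructed from the session's EventStream and AgentConfig instead of spawning its own insert threads. A hedged usage sketch (the build_memory helper and the 'local' embedding strategy are illustrative, not part of the commit; actually running it requires the llama-index extras and a live session):

from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.stream import EventStream
from openhands.memory.memory import LongTermMemory


def build_memory(event_stream: EventStream) -> LongTermMemory:
    # any unrecognized strategy, e.g. 'local', falls through to the HuggingFace default in EmbeddingsLoader
    llm_config = LLMConfig(embedding_model='local')
    agent_config = AgentConfig(memory_enabled=True, memory_max_threads=3)
    # chromadb data is persisted under ./cache/sessions/<sid>/memory
    return LongTermMemory(
        llm_config=llm_config,
        agent_config=agent_config,
        event_stream=event_stream,
    )

# memory = build_memory(event_stream)
# memory.add_event(event)                               # embeds and inserts a single Event
# hits = memory.search('what files were edited?', k=5)  # top-k texts from the vector store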


@@ -0,0 +1,176 @@
import importlib.util
import os
from joblib import Parallel, delayed
from openhands.core.config import LLMConfig
try:
# check if those we need later are available using importlib
if importlib.util.find_spec('chromadb') is None:
raise ImportError(
'chromadb is not available. Please install it using poetry install --with llama-index'
)
if (
importlib.util.find_spec(
'llama_index.core.indices.vector_store.retrievers.retriever'
)
is None
or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
is None
):
raise ImportError(
'llama_index is not available. Please install it using poetry install --with llama-index'
)
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.schema import TextNode
LLAMA_INDEX_AVAILABLE = True
except ImportError:
LLAMA_INDEX_AVAILABLE = False
# Define supported embedding models
SUPPORTED_OLLAMA_EMBED_MODELS = [
'llama2',
'mxbai-embed-large',
'nomic-embed-text',
'all-minilm',
'stable-code',
'bge-m3',
'bge-large',
'paraphrase-multilingual',
'snowflake-arctic-embed',
]
def check_llama_index():
"""Utility function to check the availability of llama_index.
Raises:
ImportError: If llama_index is not available.
"""
if not LLAMA_INDEX_AVAILABLE:
raise ImportError(
'llama_index and its dependencies are not installed. '
'To use memory features, please run: poetry install --with llama-index.'
)
class EmbeddingsLoader:
"""Loader for embedding model initialization."""
@staticmethod
def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
"""Initialize and return the appropriate embedding model based on the strategy.
Parameters:
- strategy: The embedding strategy to use.
- llm_config: Configuration for the LLM.
Returns:
- An instance of the selected embedding model or None.
"""
if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
from llama_index.embeddings.ollama import OllamaEmbedding
return OllamaEmbedding(
model_name=strategy,
base_url=llm_config.embedding_base_url,
ollama_additional_kwargs={'mirostat': 0},
)
elif strategy == 'openai':
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
return AzureOpenAIEmbedding(
model='text-embedding-ada-002',
deployment_name=llm_config.embedding_deployment_name,
api_key=llm_config.api_key,
azure_endpoint=llm_config.base_url,
api_version=llm_config.api_version,
)
elif (strategy is not None) and (strategy.lower() == 'none'):
# TODO: this works but is not elegant enough. The incentive is when
# an agent using embeddings is not used, there is no reason we need to
# initialize an embedding model
return None
else:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# initialize the local embedding model
local_embed_model = HuggingFaceEmbedding(
model_name='BAAI/bge-small-en-v1.5'
)
# for local embeddings, we need torch
import torch
# choose the best device
# first determine what is available: CUDA, MPS, or CPU
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = 'mps'
else:
device = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['PYTORCH_FORCE_CPU'] = (
'1' # try to force CPU to avoid errors
)
# override CUDA availability
torch.cuda.is_available = lambda: False
# disable MPS to avoid errors
if device != 'mps' and hasattr(torch.backends, 'mps'):
torch.backends.mps.is_available = lambda: False
torch.backends.mps.is_built = False
# the device being used
print(f'Using device for embeddings: {device}')
return local_embed_model
# --------------------------------------------------------------------------
# Utility functions to run pipelines, split out for profiling
# --------------------------------------------------------------------------
def run_pipeline(
embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
) -> list['TextNode']:
"""Run a pipeline embedding documents."""
# set up a pipeline with the transformations to make
pipeline = IngestionPipeline(
transformations=[
embed_model,
],
)
# run the pipeline with num_workers
nodes = pipeline.run(
documents=documents, show_progress=True, num_workers=num_workers
)
return nodes
def insert_batch_docs(
index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
) -> list['TextNode']:
"""Run the document indexing in parallel."""
results = Parallel(n_jobs=num_workers, backend='threading')(
delayed(index.insert)(doc) for doc in documents
)
return results
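
The new module keeps EmbeddingsLoader.get_embedding_model() as the single entry point for choosing a backend (Ollama, OpenAI, Azure OpenAI, none, or a local HuggingFace model) and adds run_pipeline/insert_batch_docs for parallel work. A sketch of how the pieces could be combined to backfill an index from an existing LongTermMemory instance (the index_existing_events wrapper is illustrative, not part of the commit):

from openhands.utils.embeddings import check_llama_index, run_pipeline


def index_existing_events(memory) -> None:
    """Illustrative helper: batch-embed all EventStream events and insert them."""
    check_llama_index()                   # raises ImportError if the llama-index extras are missing
    documents = memory._events_to_docs()  # convert the whole stream to Documents
    if not documents:
        return
    # embed in parallel workers, then insert the resulting nodes into the vector store
    nodes = run_pipeline(
        embed_model=memory.embed_model,
        documents=documents,
        num_workers=memory.memory_max_threads,
    )
    memory.index.insert_nodes(nodes)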

tests/unit/test_memory.py (new file, 246 lines)

@@ -0,0 +1,246 @@
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.memory.memory import LongTermMemory
from openhands.storage.files import FileStore
@pytest.fixture
def mock_llm_config() -> LLMConfig:
config = MagicMock(spec=LLMConfig)
config.embedding_model = 'test_embedding_model'
config.api_key = 'test_api_key'
config.api_version = 'v1'
return config
@pytest.fixture
def mock_agent_config() -> AgentConfig:
config = AgentConfig(
micro_agent_name='test_micro_agent',
memory_enabled=True,
memory_max_threads=4,
llm_config='test_llm_config',
)
return config
@pytest.fixture
def mock_file_store() -> FileStore:
store = MagicMock(spec=FileStore)
store.sid = 'test_session'
return store
@pytest.fixture
def mock_event_stream(mock_file_store: FileStore) -> EventStream:
with patch('openhands.events.stream.EventStream') as MockEventStream:
instance = MockEventStream.return_value
instance.sid = 'test_session'
instance.get_events = MagicMock()
return instance
@pytest.fixture
def long_term_memory(
mock_llm_config: LLMConfig,
mock_agent_config: AgentConfig,
mock_event_stream: EventStream,
) -> LongTermMemory:
with patch(
'openhands.memory.memory.chromadb.PersistentClient'
) as mock_chroma_client:
mock_collection = MagicMock()
mock_chroma_client.return_value.get_or_create_collection.return_value = (
mock_collection
)
memory = LongTermMemory(
llm_config=mock_llm_config,
agent_config=mock_agent_config,
event_stream=mock_event_stream,
)
memory.collection = mock_collection
return memory
def _create_action_event(action: str) -> Event:
"""Helper function to create an action event."""
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.AGENT
event.action = action
return event
def _create_observation_event(observation: str) -> Event:
"""Helper function to create an observation event."""
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.USER
event.observation = observation
return event
def test_add_event_with_action(long_term_memory: LongTermMemory):
event = _create_action_event('test_action')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'action'
assert kwargs['document'].extra_info['id'] == 'test_action'
def test_add_event_with_observation(long_term_memory: LongTermMemory):
event = _create_observation_event('test_observation')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'observation'
assert kwargs['document'].extra_info['id'] == 'test_observation'
def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
# Creating an event with additional unexpected attributes
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.AGENT
event.action = 'test_action'
event.unexpected_key = 'value'
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
assert long_term_memory.thought_idx == 1
long_term_memory._add_document.assert_called_once()
_, kwargs = long_term_memory._add_document.call_args
assert kwargs['document'].extra_info['type'] == 'action'
assert kwargs['document'].extra_info['id'] == 'test_action'
def test_events_to_docs_no_events(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
mock_event_stream.get_events.side_effect = FileNotFoundError
# convert events to documents
documents = long_term_memory._events_to_docs()
# since get_events raises, documents should be empty
assert len(documents) == 0
# thought_idx remains unchanged
assert long_term_memory.thought_idx == 0
def test_load_events_into_index_with_invalid_json(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
"""Test loading events with malformed event data."""
# Simulate an event that causes event_to_memory to raise a JSONDecodeError
with patch(
'openhands.memory.memory.event_to_memory',
side_effect=json.JSONDecodeError('Expecting value', '', 0),
):
event = _create_action_event('invalid_action')
mock_event_stream.get_events.return_value = [event]
# convert events to documents
documents = long_term_memory._events_to_docs()
# since event_to_memory raises, documents should be empty
assert len(documents) == 0
# thought_idx remains unchanged
assert long_term_memory.thought_idx == 0
def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
event = _create_action_event('test_action')
long_term_memory._add_document = MagicMock()
long_term_memory.add_event(event)
long_term_memory._add_document.assert_called()
_, kwargs = long_term_memory._add_document.call_args
assert 'document' in kwargs
assert (
kwargs['document'].text
== '{"source": "agent", "action": "test_action", "args": {}}'
)
def test_search_returns_correct_results(long_term_memory: LongTermMemory):
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [
MagicMock(get_text=MagicMock(return_value='result1')),
MagicMock(get_text=MagicMock(return_value='result2')),
]
with patch(
'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
):
results = long_term_memory.search(query='test query', k=2)
assert results == ['result1', 'result2']
mock_retriever.retrieve.assert_called_once_with('test query')
def test_search_with_no_results(long_term_memory: LongTermMemory):
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = []
with patch(
'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
):
results = long_term_memory.search(query='no results', k=5)
assert results == []
mock_retriever.retrieve.assert_called_once_with('no results')
def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
event1 = _create_action_event('action1')
event2 = _create_observation_event('observation1')
long_term_memory.add_event(event1)
long_term_memory.add_event(event2)
assert long_term_memory.thought_idx == 2
def test_load_events_batch_insert(
long_term_memory: LongTermMemory, mock_event_stream: EventStream
):
event1 = _create_action_event('action1')
event2 = _create_observation_event('observation1')
event3 = _create_action_event('action2')
mock_event_stream.get_events.return_value = [event1, event2, event3]
# Mock insert_batch_docs
with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
# convert events to documents
documents = long_term_memory._events_to_docs()
# Mock the insert_batch_docs to simulate document insertion
mock_run_docs.return_value = []
# Call insert_batch_docs with the documents
mock_run_docs(
index=long_term_memory.index,
documents=documents,
num_workers=long_term_memory.memory_max_threads,
)
# Assert that insert_batch_docs was called with the correct arguments
mock_run_docs.assert_called_once_with(
index=long_term_memory.index,
documents=documents,
num_workers=long_term_memory.memory_max_threads,
)
# Check if thought_idx was incremented correctly
assert long_term_memory.thought_idx == 3