diff --git a/.github/workflows/py-unit-tests.yml b/.github/workflows/py-unit-tests.yml index acf8d90e7b..dee4a649e0 100644 --- a/.github/workflows/py-unit-tests.yml +++ b/.github/workflows/py-unit-tests.yml @@ -93,7 +93,7 @@ jobs: id: buildx uses: docker/setup-buildx-action@v3 - name: Run Tests - run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit + run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit --ignore=tests/unit/test_memory.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 env: @@ -125,7 +125,7 @@ jobs: - name: Build Environment run: make build - name: Run Tests - run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit + run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 env: diff --git a/config.template.toml b/config.template.toml index 009ff75e77..7673744bba 100644 --- a/config.template.toml +++ b/config.template.toml @@ -185,7 +185,7 @@ model = "gpt-4o-mini" #memory_enabled = false # Memory maximum threads -#memory_max_threads = 2 +#memory_max_threads = 3 # LLM config group to use #llm_config = 'your-llm-config-group' diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py index 5c482b1de1..839d09277e 100644 --- a/openhands/core/config/agent_config.py +++ b/openhands/core/config/agent_config.py @@ -16,7 +16,7 @@ class AgentConfig: micro_agent_name: str | None = None memory_enabled: bool = False - memory_max_threads: int = 2 + memory_max_threads: int = 3 llm_config: str | None = None def defaults_to_dict(self) -> dict: diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index b77fcd9395..bf5fb72cee 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -96,7 +96,7 @@ def event_to_memory(event: 'Event', max_message_chars: int) -> dict: def truncate_content(content: str, max_chars: int) -> str: """Truncate the middle of the observation content if it is too long.""" - if len(content) <= max_chars: + if len(content) <= max_chars or max_chars == -1: return content # truncate the middle and include a message to the LLM about it diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py index 881775e7a8..9d83cc9cdc 100644 --- a/openhands/memory/memory.py +++ b/openhands/memory/memory.py @@ -1,189 +1,187 @@ -import threading +import json -from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_random_exponential, +from openhands.core.config import AgentConfig, LLMConfig +from openhands.core.logger import openhands_logger as logger +from openhands.events.event import Event +from openhands.events.serialization.event import event_to_memory +from openhands.events.stream import EventStream +from openhands.utils.embeddings import ( + LLAMA_INDEX_AVAILABLE, + EmbeddingsLoader, + check_llama_index, ) -from openhands.core.config import LLMConfig -from openhands.core.logger import openhands_logger as logger -from openhands.core.utils import json -from openhands.utils.tenacity_stop import stop_if_should_exit - -try: - import chromadb - import llama_index.embeddings.openai.base as llama_openai - from llama_index.core import Document, VectorStoreIndex - from 
llama_index.core.retrievers import VectorIndexRetriever - from llama_index.vector_stores.chroma import ChromaVectorStore - - LLAMA_INDEX_AVAILABLE = True -except ImportError: - LLAMA_INDEX_AVAILABLE = False - +# Conditional imports based on llama_index availability if LLAMA_INDEX_AVAILABLE: - # TODO: this could be made configurable - num_retries: int = 10 - retry_min_wait: int = 3 - retry_max_wait: int = 300 - - # llama-index includes a retry decorator around openai.get_embeddings() function - # it is initialized with hard-coded values and errors - # this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits - # this block attempts to banish it and replace it with our decorator, to allow users to set their own limits - - if hasattr(llama_openai.get_embeddings, '__wrapped__'): - original_get_embeddings = llama_openai.get_embeddings.__wrapped__ - else: - logger.warning('Cannot set custom retry limits.') - num_retries = 1 - original_get_embeddings = llama_openai.get_embeddings - - def attempt_on_error(retry_state): - logger.error( - f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.', - exc_info=False, - ) - return None - - @retry( - reraise=True, - stop=stop_after_attempt(num_retries) | stop_if_should_exit(), - wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait), - retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, InternalServerError) - ), - after=attempt_on_error, + import chromadb + from llama_index.core import Document + from llama_index.core.indices.vector_store.base import VectorStoreIndex + from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, ) - def wrapper_get_embeddings(*args, **kwargs): - return original_get_embeddings(*args, **kwargs) - - llama_openai.get_embeddings = wrapper_get_embeddings - - class EmbeddingsLoader: - """Loader for embedding model initialization.""" - - @staticmethod - def get_embedding_model(strategy: str, llm_config: LLMConfig): - supported_ollama_embed_models = [ - 'llama2', - 'mxbai-embed-large', - 'nomic-embed-text', - 'all-minilm', - 'stable-code', - 'bge-m3', - 'bge-large', - 'paraphrase-multilingual', - 'snowflake-arctic-embed', - ] - if strategy in supported_ollama_embed_models: - from llama_index.embeddings.ollama import OllamaEmbedding - - return OllamaEmbedding( - model_name=strategy, - base_url=llm_config.embedding_base_url, - ollama_additional_kwargs={'mirostat': 0}, - ) - elif strategy == 'openai': - from llama_index.embeddings.openai import OpenAIEmbedding - - return OpenAIEmbedding( - model='text-embedding-ada-002', - api_key=llm_config.api_key, - ) - elif strategy == 'azureopenai': - from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding - - return AzureOpenAIEmbedding( - model='text-embedding-ada-002', - deployment_name=llm_config.embedding_deployment_name, - api_key=llm_config.api_key, - azure_endpoint=llm_config.base_url, - api_version=llm_config.api_version, - ) - elif (strategy is not None) and (strategy.lower() == 'none'): - # TODO: this works but is not elegant enough. 
The incentive is when - # an agent using embeddings is not used, there is no reason we need to - # initialize an embedding model - return None - else: - from llama_index.embeddings.huggingface import HuggingFaceEmbedding - - return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5') + from llama_index.core.schema import TextNode + from llama_index.vector_stores.chroma import ChromaVectorStore class LongTermMemory: """Handles storing information for the agent to access later, using chromadb.""" - def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1): - """Initialize the chromadb and set up ChromaVectorStore for later use.""" - if not LLAMA_INDEX_AVAILABLE: - raise ImportError( - 'llama_index and its dependencies are not installed. ' - 'To use LongTermMemory, please run: poetry install --with llama-index' - ) + event_stream: EventStream - db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False)) + def __init__( + self, + llm_config: LLMConfig, + agent_config: AgentConfig, + event_stream: EventStream, + ): + """Initialize the chromadb and set up ChromaVectorStore for later use.""" + + check_llama_index() + + # initialize the chromadb client + db = chromadb.PersistentClient( + path=f'./cache/sessions/{event_stream.sid}/memory', + # FIXME anonymized_telemetry=False, + ) self.collection = db.get_or_create_collection(name='memories') vector_store = ChromaVectorStore(chroma_collection=self.collection) + + # embedding model embedding_strategy = llm_config.embedding_model - embed_model = EmbeddingsLoader.get_embedding_model( + self.embed_model = EmbeddingsLoader.get_embedding_model( embedding_strategy, llm_config ) - self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model) - self.sema = threading.Semaphore(value=memory_max_threads) - self.thought_idx = 0 - self._add_threads: list[threading.Thread] = [] - def add_event(self, event: dict): + # instantiate the index + self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model) + self.thought_idx = 0 + + # initialize the event stream + self.event_stream = event_stream + + # max of threads to run the pipeline + self.memory_max_threads = agent_config.memory_max_threads + + def add_event(self, event: Event): """Adds a new event to the long term memory with a unique id. 
Parameters: - - event (dict): The new event to be added to memory + - event: The new event to be added to memory """ - id = '' - t = '' - if 'action' in event: - t = 'action' - id = event['action'] - elif 'observation' in event: - t = 'observation' - id = event['observation'] + try: + # convert the event to a memory-friendly format, and don't truncate + event_data = event_to_memory(event, -1) + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f'Failed to process event: {e}') + return + + # determine the event type and ID + event_type = '' + event_id = '' + if 'action' in event_data: + event_type = 'action' + event_id = event_data['action'] + elif 'observation' in event_data: + event_type = 'observation' + event_id = event_data['observation'] + + # create a Document instance for the event doc = Document( - text=json.dumps(event), + text=json.dumps(event_data), doc_id=str(self.thought_idx), extra_info={ - 'type': t, - 'id': id, + 'type': event_type, + 'id': event_id, 'idx': self.thought_idx, }, ) self.thought_idx += 1 - logger.debug('Adding %s event to memory: %d', t, self.thought_idx) - thread = threading.Thread(target=self._add_doc, args=(doc,)) - self._add_threads.append(thread) - thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert + logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx) + self._add_document(document=doc) - def _add_doc(self, doc): - with self.sema: - self.index.insert(doc) + def _add_document(self, document: 'Document'): + """Inserts a single document into the index.""" + self.index.insert_nodes([self._create_node(document)]) - def search(self, query: str, k: int = 10): - """Searches through the current memory using VectorIndexRetriever + def _create_node(self, document: 'Document') -> 'TextNode': + """Create a TextNode from a Document instance.""" + return TextNode( + text=document.text, + doc_id=document.doc_id, + extra_info=document.extra_info, + ) + + def search(self, query: str, k: int = 10) -> list[str]: + """Searches through the current memory using VectorIndexRetriever. 
Parameters: - query (str): A query to match search results to - k (int): Number of top results to return Returns: - - list[str]: list of top k results found in current memory + - list[str]: List of top k results found in current memory """ retriever = VectorIndexRetriever( index=self.index, similarity_top_k=k, ) results = retriever.retrieve(query) + + for result in results: + logger.debug( + f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}' + ) + return [r.get_text() for r in results] + + def _events_to_docs(self) -> list['Document']: + """Convert all events from the EventStream to documents for batch insert into the index.""" + try: + events = self.event_stream.get_events() + except Exception as e: + logger.debug(f'No events found for session {self.event_stream.sid}: {e}') + return [] + + documents: list[Document] = [] + + for event in events: + try: + # convert the event to a memory-friendly format, and don't truncate + event_data = event_to_memory(event, -1) + + # determine the event type and ID + event_type = '' + event_id = '' + if 'action' in event_data: + event_type = 'action' + event_id = event_data['action'] + elif 'observation' in event_data: + event_type = 'observation' + event_id = event_data['observation'] + + # create a Document instance for the event + doc = Document( + text=json.dumps(event_data), + doc_id=str(self.thought_idx), + extra_info={ + 'type': event_type, + 'id': event_id, + 'idx': self.thought_idx, + }, + ) + documents.append(doc) + self.thought_idx += 1 + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f'Failed to process event: {e}') + continue + + if documents: + logger.debug(f'Batch inserting {len(documents)} documents into the index.') + else: + logger.debug('No valid documents found to insert into the index.') + + return documents + + def create_nodes(self, documents: list['Document']) -> list['TextNode']: + """Create nodes from a list of documents.""" + return [self._create_node(doc) for doc in documents] diff --git a/openhands/utils/embeddings.py b/openhands/utils/embeddings.py new file mode 100644 index 0000000000..07ee2d27f5 --- /dev/null +++ b/openhands/utils/embeddings.py @@ -0,0 +1,176 @@ +import importlib.util +import os + +from joblib import Parallel, delayed + +from openhands.core.config import LLMConfig + +try: + # check if those we need later are available using importlib + if importlib.util.find_spec('chromadb') is None: + raise ImportError( + 'chromadb is not available. Please install it using poetry install --with llama-index' + ) + + if ( + importlib.util.find_spec( + 'llama_index.core.indices.vector_store.retrievers.retriever' + ) + is None + or importlib.util.find_spec('llama_index.core.indices.vector_store.base') + is None + ): + raise ImportError( + 'llama_index is not available. 
Please install it using poetry install --with llama-index' + ) + + from llama_index.core import Document, VectorStoreIndex + from llama_index.core.base.embeddings.base import BaseEmbedding + from llama_index.core.ingestion import IngestionPipeline + from llama_index.core.schema import TextNode + + LLAMA_INDEX_AVAILABLE = True + +except ImportError: + LLAMA_INDEX_AVAILABLE = False + +# Define supported embedding models +SUPPORTED_OLLAMA_EMBED_MODELS = [ + 'llama2', + 'mxbai-embed-large', + 'nomic-embed-text', + 'all-minilm', + 'stable-code', + 'bge-m3', + 'bge-large', + 'paraphrase-multilingual', + 'snowflake-arctic-embed', +] + + +def check_llama_index(): + """Utility function to check the availability of llama_index. + + Raises: + ImportError: If llama_index is not available. + """ + if not LLAMA_INDEX_AVAILABLE: + raise ImportError( + 'llama_index and its dependencies are not installed. ' + 'To use memory features, please run: poetry install --with llama-index.' + ) + + +class EmbeddingsLoader: + """Loader for embedding model initialization.""" + + @staticmethod + def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding': + """Initialize and return the appropriate embedding model based on the strategy. + + Parameters: + - strategy: The embedding strategy to use. + - llm_config: Configuration for the LLM. + + Returns: + - An instance of the selected embedding model or None. + """ + + if strategy in SUPPORTED_OLLAMA_EMBED_MODELS: + from llama_index.embeddings.ollama import OllamaEmbedding + + return OllamaEmbedding( + model_name=strategy, + base_url=llm_config.embedding_base_url, + ollama_additional_kwargs={'mirostat': 0}, + ) + elif strategy == 'openai': + from llama_index.embeddings.openai import OpenAIEmbedding + + return OpenAIEmbedding( + model='text-embedding-ada-002', + api_key=llm_config.api_key, + ) + elif strategy == 'azureopenai': + from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding + + return AzureOpenAIEmbedding( + model='text-embedding-ada-002', + deployment_name=llm_config.embedding_deployment_name, + api_key=llm_config.api_key, + azure_endpoint=llm_config.base_url, + api_version=llm_config.api_version, + ) + elif (strategy is not None) and (strategy.lower() == 'none'): + # TODO: this works but is not elegant enough. 
The incentive is when + # an agent using embeddings is not used, there is no reason we need to + # initialize an embedding model + return None + else: + from llama_index.embeddings.huggingface import HuggingFaceEmbedding + + # initialize the local embedding model + local_embed_model = HuggingFaceEmbedding( + model_name='BAAI/bge-small-en-v1.5' + ) + + # for local embeddings, we need torch + import torch + + # choose the best device + # first determine what is available: CUDA, MPS, or CPU + if torch.cuda.is_available(): + device = 'cuda' + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = 'mps' + else: + device = 'cpu' + os.environ['CUDA_VISIBLE_DEVICES'] = '' + os.environ['PYTORCH_FORCE_CPU'] = ( + '1' # try to force CPU to avoid errors + ) + + # override CUDA availability + torch.cuda.is_available = lambda: False + + # disable MPS to avoid errors + if device != 'mps' and hasattr(torch.backends, 'mps'): + torch.backends.mps.is_available = lambda: False + torch.backends.mps.is_built = False + + # the device being used + print(f'Using device for embeddings: {device}') + + return local_embed_model + + +# -------------------------------------------------------------------------- +# Utility functions to run pipelines, split out for profiling +# -------------------------------------------------------------------------- +def run_pipeline( + embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int +) -> list['TextNode']: + """Run a pipeline embedding documents.""" + + # set up a pipeline with the transformations to make + pipeline = IngestionPipeline( + transformations=[ + embed_model, + ], + ) + + # run the pipeline with num_workers + nodes = pipeline.run( + documents=documents, show_progress=True, num_workers=num_workers + ) + return nodes + + +def insert_batch_docs( + index: 'VectorStoreIndex', documents: list['Document'], num_workers: int +) -> list['TextNode']: + """Run the document indexing in parallel.""" + results = Parallel(n_jobs=num_workers, backend='threading')( + delayed(index.insert)(doc) for doc in documents + ) + return results diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py new file mode 100644 index 0000000000..49c52bb268 --- /dev/null +++ b/tests/unit/test_memory.py @@ -0,0 +1,246 @@ +import json +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from openhands.core.config import AgentConfig, LLMConfig +from openhands.events.event import Event, EventSource +from openhands.events.stream import EventStream +from openhands.memory.memory import LongTermMemory +from openhands.storage.files import FileStore + + +@pytest.fixture +def mock_llm_config() -> LLMConfig: + config = MagicMock(spec=LLMConfig) + config.embedding_model = 'test_embedding_model' + config.api_key = 'test_api_key' + config.api_version = 'v1' + return config + + +@pytest.fixture +def mock_agent_config() -> AgentConfig: + config = AgentConfig( + micro_agent_name='test_micro_agent', + memory_enabled=True, + memory_max_threads=4, + llm_config='test_llm_config', + ) + return config + + +@pytest.fixture +def mock_file_store() -> FileStore: + store = MagicMock(spec=FileStore) + store.sid = 'test_session' + return store + + +@pytest.fixture +def mock_event_stream(mock_file_store: FileStore) -> EventStream: + with patch('openhands.events.stream.EventStream') as MockEventStream: + instance = MockEventStream.return_value + instance.sid = 'test_session' + instance.get_events = MagicMock() + return 
instance + + +@pytest.fixture +def long_term_memory( + mock_llm_config: LLMConfig, + mock_agent_config: AgentConfig, + mock_event_stream: EventStream, +) -> LongTermMemory: + with patch( + 'openhands.memory.memory.chromadb.PersistentClient' + ) as mock_chroma_client: + mock_collection = MagicMock() + mock_chroma_client.return_value.get_or_create_collection.return_value = ( + mock_collection + ) + memory = LongTermMemory( + llm_config=mock_llm_config, + agent_config=mock_agent_config, + event_stream=mock_event_stream, + ) + memory.collection = mock_collection + return memory + + +def _create_action_event(action: str) -> Event: + """Helper function to create an action event.""" + event = Event() + event._id = -1 + event._timestamp = datetime.now(timezone.utc).isoformat() + event._source = EventSource.AGENT + event.action = action + return event + + +def _create_observation_event(observation: str) -> Event: + """Helper function to create an observation event.""" + event = Event() + event._id = -1 + event._timestamp = datetime.now(timezone.utc).isoformat() + event._source = EventSource.USER + event.observation = observation + return event + + +def test_add_event_with_action(long_term_memory: LongTermMemory): + event = _create_action_event('test_action') + long_term_memory._add_document = MagicMock() + long_term_memory.add_event(event) + assert long_term_memory.thought_idx == 1 + long_term_memory._add_document.assert_called_once() + _, kwargs = long_term_memory._add_document.call_args + assert kwargs['document'].extra_info['type'] == 'action' + assert kwargs['document'].extra_info['id'] == 'test_action' + + +def test_add_event_with_observation(long_term_memory: LongTermMemory): + event = _create_observation_event('test_observation') + long_term_memory._add_document = MagicMock() + long_term_memory.add_event(event) + assert long_term_memory.thought_idx == 1 + long_term_memory._add_document.assert_called_once() + _, kwargs = long_term_memory._add_document.call_args + assert kwargs['document'].extra_info['type'] == 'observation' + assert kwargs['document'].extra_info['id'] == 'test_observation' + + +def test_add_event_with_missing_keys(long_term_memory: LongTermMemory): + # Creating an event with additional unexpected attributes + event = Event() + event._id = -1 + event._timestamp = datetime.now(timezone.utc).isoformat() + event._source = EventSource.AGENT + event.action = 'test_action' + event.unexpected_key = 'value' + + long_term_memory._add_document = MagicMock() + long_term_memory.add_event(event) + assert long_term_memory.thought_idx == 1 + long_term_memory._add_document.assert_called_once() + _, kwargs = long_term_memory._add_document.call_args + assert kwargs['document'].extra_info['type'] == 'action' + assert kwargs['document'].extra_info['id'] == 'test_action' + + +def test_events_to_docs_no_events( + long_term_memory: LongTermMemory, mock_event_stream: EventStream +): + mock_event_stream.get_events.side_effect = FileNotFoundError + + # convert events to documents + documents = long_term_memory._events_to_docs() + + # since get_events raises, documents should be empty + assert len(documents) == 0 + + # thought_idx remains unchanged + assert long_term_memory.thought_idx == 0 + + +def test_load_events_into_index_with_invalid_json( + long_term_memory: LongTermMemory, mock_event_stream: EventStream +): + """Test loading events with malformed event data.""" + # Simulate an event that causes event_to_memory to raise a JSONDecodeError + with patch( + 'openhands.memory.memory.event_to_memory', 
+ side_effect=json.JSONDecodeError('Expecting value', '', 0), + ): + event = _create_action_event('invalid_action') + mock_event_stream.get_events.return_value = [event] + + # convert events to documents + documents = long_term_memory._events_to_docs() + + # since event_to_memory raises, documents should be empty + assert len(documents) == 0 + + # thought_idx remains unchanged + assert long_term_memory.thought_idx == 0 + + +def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory): + event = _create_action_event('test_action') + long_term_memory._add_document = MagicMock() + long_term_memory.add_event(event) + long_term_memory._add_document.assert_called() + _, kwargs = long_term_memory._add_document.call_args + assert 'document' in kwargs + assert ( + kwargs['document'].text + == '{"source": "agent", "action": "test_action", "args": {}}' + ) + + +def test_search_returns_correct_results(long_term_memory: LongTermMemory): + mock_retriever = MagicMock() + mock_retriever.retrieve.return_value = [ + MagicMock(get_text=MagicMock(return_value='result1')), + MagicMock(get_text=MagicMock(return_value='result2')), + ] + with patch( + 'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever + ): + results = long_term_memory.search(query='test query', k=2) + assert results == ['result1', 'result2'] + mock_retriever.retrieve.assert_called_once_with('test query') + + +def test_search_with_no_results(long_term_memory: LongTermMemory): + mock_retriever = MagicMock() + mock_retriever.retrieve.return_value = [] + with patch( + 'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever + ): + results = long_term_memory.search(query='no results', k=5) + assert results == [] + mock_retriever.retrieve.assert_called_once_with('no results') + + +def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory): + event1 = _create_action_event('action1') + event2 = _create_observation_event('observation1') + long_term_memory.add_event(event1) + long_term_memory.add_event(event2) + assert long_term_memory.thought_idx == 2 + + +def test_load_events_batch_insert( + long_term_memory: LongTermMemory, mock_event_stream: EventStream +): + event1 = _create_action_event('action1') + event2 = _create_observation_event('observation1') + event3 = _create_action_event('action2') + mock_event_stream.get_events.return_value = [event1, event2, event3] + + # Mock insert_batch_docs + with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs: + # convert events to documents + documents = long_term_memory._events_to_docs() + + # Mock the insert_batch_docs to simulate document insertion + mock_run_docs.return_value = [] + + # Call insert_batch_docs with the documents + mock_run_docs( + index=long_term_memory.index, + documents=documents, + num_workers=long_term_memory.memory_max_threads, + ) + + # Assert that insert_batch_docs was called with the correct arguments + mock_run_docs.assert_called_once_with( + index=long_term_memory.index, + documents=documents, + num_workers=long_term_memory.memory_max_threads, + ) + + # Check if thought_idx was incremented correctly + assert long_term_memory.thought_idx == 3
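
For context, a minimal usage sketch of the API this patch introduces, assuming the llama-index extras are installed (poetry install --with llama-index) and that the patched openhands modules are importable. The LocalFileStore backend, the 'demo_session' id, and the query string are illustrative choices, not taken from this patch; the constructor and method signatures follow the new memory.py and embeddings.py above.

# Sketch only: exercises LongTermMemory and the batch-insert helper added in this patch.
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.stream import EventStream
from openhands.memory.memory import LongTermMemory
from openhands.storage.local import LocalFileStore  # illustrative FileStore backend
from openhands.utils.embeddings import insert_batch_docs

# 'local' is not an Ollama/OpenAI/Azure strategy, so EmbeddingsLoader falls back
# to the HuggingFace BAAI/bge-small-en-v1.5 model.
llm_config = LLMConfig(embedding_model='local')
agent_config = AgentConfig(memory_enabled=True, memory_max_threads=3)
event_stream = EventStream(sid='demo_session', file_store=LocalFileStore('./cache'))

memory = LongTermMemory(
    llm_config=llm_config,
    agent_config=agent_config,
    event_stream=event_stream,
)

# Index events one at a time as they arrive...
for event in event_stream.get_events():
    memory.add_event(event)

# ...or convert the whole stream to documents and insert them with threaded workers.
documents = memory._events_to_docs()
insert_batch_docs(
    index=memory.index,
    documents=documents,
    num_workers=memory.memory_max_threads,
)

# Retrieve the top-k stored events most similar to a query.
for text in memory.search('what did the agent run last?', k=5):
    print(text)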