diff --git a/openhands/storage/local.py b/openhands/storage/local.py
index fcb766c0ef..e646f3137e 100644
--- a/openhands/storage/local.py
+++ b/openhands/storage/local.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+import threading
 
 from openhands.core.logger import openhands_logger as logger
 from openhands.storage.files import FileStore
@@ -23,8 +24,28 @@ class LocalFileStore(FileStore):
         full_path = self.get_full_path(path)
         os.makedirs(os.path.dirname(full_path), exist_ok=True)
         mode = 'w' if isinstance(contents, str) else 'wb'
-        with open(full_path, mode) as f:
-            f.write(contents)
+
+        # Use atomic write: write to temp file, then rename
+        # This prevents race conditions where concurrent writes could corrupt the file
+        # The PID and thread id keep temp names distinct per concurrent writer.
+        temp_path = f'{full_path}.tmp.{os.getpid()}.{threading.get_ident()}'
+        try:
+            with open(temp_path, mode) as f:
+                f.write(contents)
+                # Push the data to disk before the rename makes it visible
+                f.flush()
+                os.fsync(f.fileno())
+            os.replace(temp_path, full_path)
+        except BaseException:
+            # BaseException (not just Exception) so that KeyboardInterrupt /
+            # SystemExit do not leave a stray temp file behind.
+            try:
+                os.remove(temp_path)
+            except OSError:
+                # Temp file was never created or is already gone; a cleanup
+                # failure must not mask the original error.
+                pass
+            raise
 
     def read(self, path: str) -> str:
         full_path = self.get_full_path(path)
diff --git a/tests/unit/storage/test_storage.py b/tests/unit/storage/test_storage.py
index a78c12df98..5d2508705f 100644
--- a/tests/unit/storage/test_storage.py
+++ b/tests/unit/storage/test_storage.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 import shutil
 import tempfile
+import threading
 from abc import ABC
 from dataclasses import dataclass, field
 from io import BytesIO, StringIO
@@ -122,6 +123,57 @@ class TestLocalFileStore(TestCase, _StorageTest):
                 f'Failed to remove temporary directory {self.temp_dir}: {e}'
             )
 
+    def test_concurrent_writes_no_corruption(self):
+        """Test that concurrent writes don't corrupt file content.
+
+        This test verifies the atomic write fix by having 9 threads write
+        progressively shorter strings to the same file simultaneously.
+        Without atomic writes, a shorter write following a longer write
+        could result in corrupted content (e.g., "123" followed by garbage
+        from the previous longer write).
+
+        The final content must be exactly one of the valid strings written,
+        with no trailing garbage from other writes.
+        """
+        filename = 'concurrent_test.txt'
+        # Strings from longest to shortest: "123456789", "12345678", ..., "1"
+        valid_contents = ['123456789'[:i] for i in range(9, 0, -1)]
+        errors: list[Exception] = []
+        barrier = threading.Barrier(len(valid_contents))
+
+        def write_content(content: str):
+            try:
+                # Wait for all threads to be ready before writing
+                barrier.wait()
+                self.store.write(filename, content)
+            except Exception as e:
+                errors.append(e)
+
+        # Start all threads
+        threads = [
+            threading.Thread(target=write_content, args=(content,))
+            for content in valid_contents
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        # Check for errors during writes
+        self.assertEqual(
+            errors, [], f'Errors occurred during concurrent writes: {errors}'
+        )
+
+        # Read final content and verify it's one of the valid strings
+        final_content = self.store.read(filename)
+        self.assertIn(
+            final_content,
+            valid_contents,
+            f"File content '{final_content}' is not one of the valid strings. "
+            f'Length: {len(final_content)}. This indicates file corruption from '
+            f'concurrent writes (e.g., shorter write did not fully replace longer write).',
+        )
+
 
 class TestInMemoryFileStore(TestCase, _StorageTest):
     def setUp(self):