mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix: use atomic write in LocalFileStore to prevent race conditions (#13480)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: OpenHands Bot <contact@all-hands.dev>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.files import FileStore
|
||||
@@ -23,8 +24,20 @@ class LocalFileStore(FileStore):
|
||||
full_path = self.get_full_path(path)
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
mode = 'w' if isinstance(contents, str) else 'wb'
|
||||
with open(full_path, mode) as f:
|
||||
f.write(contents)
|
||||
|
||||
# Use atomic write: write to temp file, then rename
|
||||
# This prevents race conditions where concurrent writes could corrupt the file
|
||||
temp_path = f'{full_path}.tmp.{os.getpid()}.{threading.get_ident()}'
|
||||
try:
|
||||
with open(temp_path, mode) as f:
|
||||
f.write(contents)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(temp_path, full_path)
|
||||
except Exception:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
raise
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
full_path = self.get_full_path(path)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from io import BytesIO, StringIO
|
||||
@@ -122,6 +123,57 @@ class TestLocalFileStore(TestCase, _StorageTest):
|
||||
f'Failed to remove temporary directory {self.temp_dir}: {e}'
|
||||
)
|
||||
|
||||
def test_concurrent_writes_no_corruption(self):
|
||||
"""Test that concurrent writes don't corrupt file content.
|
||||
|
||||
This test verifies the atomic write fix by having 9 threads write
|
||||
progressively shorter strings to the same file simultaneously.
|
||||
Without atomic writes, a shorter write following a longer write
|
||||
could result in corrupted content (e.g., "123" followed by garbage
|
||||
from the previous longer write).
|
||||
|
||||
The final content must be exactly one of the valid strings written,
|
||||
with no trailing garbage from other writes.
|
||||
"""
|
||||
filename = 'concurrent_test.txt'
|
||||
# Strings from longest to shortest: "123456789", "12345678", ..., "1"
|
||||
valid_contents = ['123456789'[:i] for i in range(9, 0, -1)]
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(len(valid_contents))
|
||||
|
||||
def write_content(content: str):
|
||||
try:
|
||||
# Wait for all threads to be ready before writing
|
||||
barrier.wait()
|
||||
self.store.write(filename, content)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
# Start all threads
|
||||
threads = [
|
||||
threading.Thread(target=write_content, args=(content,))
|
||||
for content in valid_contents
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Check for errors during writes
|
||||
self.assertEqual(
|
||||
errors, [], f'Errors occurred during concurrent writes: {errors}'
|
||||
)
|
||||
|
||||
# Read final content and verify it's one of the valid strings
|
||||
final_content = self.store.read(filename)
|
||||
self.assertIn(
|
||||
final_content,
|
||||
valid_contents,
|
||||
f"File content '{final_content}' is not one of the valid strings. "
|
||||
f'Length: {len(final_content)}. This indicates file corruption from '
|
||||
f'concurrent writes (e.g., shorter write did not fully replace longer write).',
|
||||
)
|
||||
|
||||
|
||||
class TestInMemoryFileStore(TestCase, _StorageTest):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user