Feat google cloud storage (#3574)

* Google cloud storage implementation
* Unit test refactor
This commit is contained in:
tofarr
2024-08-26 08:16:49 -06:00
committed by GitHub
parent f1882ba886
commit 8c4c3b18b5
3 changed files with 180 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
from openhands.storage.files import FileStore
from openhands.storage.google_cloud import GoogleCloudFileStore
from openhands.storage.local import LocalFileStore
from openhands.storage.memory import InMemoryFileStore
from openhands.storage.s3 import S3FileStore
@@ -11,4 +12,6 @@ def get_file_store(file_store: str, file_store_path: str | None = None) -> FileS
return LocalFileStore(file_store_path)
elif file_store == 's3':
return S3FileStore()
elif file_store == 'google_cloud':
return GoogleCloudFileStore()
return InMemoryFileStore()

View File

@@ -0,0 +1,59 @@
import os
from typing import List, Optional
from google.cloud import storage
from openhands.storage.files import FileStore
class GoogleCloudFileStore(FileStore):
def __init__(self, bucket_name: Optional[str] = None) -> None:
"""
Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the
environment it will be used for authentication. Otherwise access will be
anonymous.
"""
if bucket_name is None:
bucket_name = os.environ['GOOGLE_CLOUD_BUCKET_NAME']
self.storage_client = storage.Client()
self.bucket = self.storage_client.bucket(bucket_name)
def write(self, path: str, contents: str | bytes) -> None:
blob = self.bucket.blob(path)
with blob.open('w') as f:
f.write(contents)
def read(self, path: str) -> str:
blob = self.bucket.blob(path)
with blob.open('r') as f:
return f.read()
def list(self, path: str) -> List[str]:
if not path or path == '/':
path = ''
elif not path.endswith('/'):
path += '/'
# The delimiter logic screens out directories, so we can't use it. :(
# For example, given a structure:
# foo/bar/zap.txt
# foo/bar/bang.txt
# ping.txt
# prefix=None, delimiter="/" yields ["ping.txt"] # :(
# prefix="foo", delimiter="/" yields [] # :(
blobs = set()
prefix_len = len(path)
for blob in self.bucket.list_blobs(prefix=path):
name = blob.name
if name == path:
continue
try:
index = name.index('/', prefix_len + 1)
if index != prefix_len:
blobs.add(name[: index + 1])
except ValueError:
blobs.add(name)
return list(blobs)
def delete(self, path: str) -> None:
blob = self.bucket.blob(path)
blob.delete()