mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Add type hints to storage directory (#7110)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -49,12 +49,12 @@ class ConversationStore(ABC):
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ConversationMetadataResultSet:
|
||||
"""Search conversations"""
|
||||
"""Search conversations."""
|
||||
|
||||
async def get_all_metadata(
|
||||
self, conversation_ids: Iterable[str]
|
||||
) -> list[ConversationMetadata]:
|
||||
"""Get metadata for multiple conversations in parallel"""
|
||||
"""Get metadata for multiple conversations in parallel."""
|
||||
return await wait_all([self.get_metadata(cid) for cid in conversation_ids])
|
||||
|
||||
@classmethod
|
||||
@@ -62,4 +62,4 @@ class ConversationStore(ABC):
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
"""Get a store for the user represented by the token given."""
|
||||
|
||||
@@ -1,39 +1,41 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from google.api_core.exceptions import NotFound
|
||||
from google.cloud import storage
|
||||
from google.cloud import storage # type: ignore
|
||||
from google.cloud.storage.blob import Blob # type: ignore
|
||||
from google.cloud.storage.bucket import Bucket # type: ignore
|
||||
from google.cloud.storage.client import Client # type: ignore
|
||||
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class GoogleCloudFileStore(FileStore):
|
||||
def __init__(self, bucket_name: str | None = None) -> None:
|
||||
"""
|
||||
Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the
|
||||
environment it will be used for authentication. Otherwise access will be
|
||||
anonymous.
|
||||
"""Create a new FileStore.
|
||||
|
||||
If GOOGLE_APPLICATION_CREDENTIALS is defined in the environment it will be used
|
||||
for authentication. Otherwise access will be anonymous.
|
||||
"""
|
||||
if bucket_name is None:
|
||||
bucket_name = os.environ['GOOGLE_CLOUD_BUCKET_NAME']
|
||||
self.storage_client = storage.Client()
|
||||
self.bucket = self.storage_client.bucket(bucket_name)
|
||||
self.storage_client: Client = storage.Client()
|
||||
self.bucket: Bucket = self.storage_client.bucket(bucket_name)
|
||||
|
||||
def write(self, path: str, contents: str | bytes) -> None:
|
||||
blob = self.bucket.blob(path)
|
||||
blob: Blob = self.bucket.blob(path)
|
||||
mode = 'wb' if isinstance(contents, bytes) else 'w'
|
||||
with blob.open(mode) as f:
|
||||
f.write(contents)
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
blob = self.bucket.blob(path)
|
||||
blob: Blob = self.bucket.blob(path)
|
||||
try:
|
||||
with blob.open('r') as f:
|
||||
return str(f.read())
|
||||
except NotFound as err:
|
||||
raise FileNotFoundError(err)
|
||||
|
||||
def list(self, path: str) -> List[str]:
|
||||
def list(self, path: str) -> list[str]:
|
||||
if not path or path == '/':
|
||||
path = ''
|
||||
elif not path.endswith('/'):
|
||||
@@ -45,10 +47,10 @@ class GoogleCloudFileStore(FileStore):
|
||||
# ping.txt
|
||||
# prefix=None, delimiter="/" yields ["ping.txt"] # :(
|
||||
# prefix="foo", delimiter="/" yields [] # :(
|
||||
blobs = set()
|
||||
blobs: set[str] = set()
|
||||
prefix_len = len(path)
|
||||
for blob in self.bucket.list_blobs(prefix=path):
|
||||
name = blob.name
|
||||
name: str = blob.name
|
||||
if name == path:
|
||||
continue
|
||||
try:
|
||||
@@ -72,7 +74,7 @@ class GoogleCloudFileStore(FileStore):
|
||||
|
||||
# Next try to delete item as a file
|
||||
try:
|
||||
blob = self.bucket.blob(path)
|
||||
blob.delete()
|
||||
file_blob: Blob = self.bucket.blob(path)
|
||||
file_blob.delete()
|
||||
except NotFound:
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, TypedDict
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
@@ -6,6 +7,18 @@ import botocore
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class S3ObjectDict(TypedDict):
|
||||
Key: str
|
||||
|
||||
|
||||
class GetObjectOutputDict(TypedDict):
|
||||
Body: Any
|
||||
|
||||
|
||||
class ListObjectsV2OutputDict(TypedDict):
|
||||
Contents: List[S3ObjectDict] | None
|
||||
|
||||
|
||||
class S3FileStore(FileStore):
|
||||
def __init__(self, bucket_name: str | None) -> None:
|
||||
access_key = os.getenv('AWS_ACCESS_KEY_ID')
|
||||
@@ -14,8 +27,8 @@ class S3FileStore(FileStore):
|
||||
endpoint = self._ensure_url_scheme(secure, os.getenv('AWS_S3_ENDPOINT'))
|
||||
if bucket_name is None:
|
||||
bucket_name = os.environ['AWS_S3_BUCKET']
|
||||
self.bucket = bucket_name
|
||||
self.client = boto3.client(
|
||||
self.bucket: str = bucket_name
|
||||
self.client: Any = boto3.client(
|
||||
's3',
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
@@ -44,8 +57,10 @@ class S3FileStore(FileStore):
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket, Key=path)
|
||||
with response['Body'] as stream:
|
||||
response: GetObjectOutputDict = self.client.get_object(
|
||||
Bucket=self.bucket, Key=path
|
||||
)
|
||||
with response['Body'] as stream: # type: ignore
|
||||
return str(stream.read().decode('utf-8'))
|
||||
except botocore.exceptions.ClientError as e:
|
||||
# Catch all S3-related errors
|
||||
@@ -78,13 +93,15 @@ class S3FileStore(FileStore):
|
||||
# ping.txt
|
||||
# prefix=None, delimiter="/" yields ["ping.txt"] # :(
|
||||
# prefix="foo", delimiter="/" yields [] # :(
|
||||
results = set()
|
||||
results: set[str] = set()
|
||||
prefix_len = len(path)
|
||||
response = self.client.list_objects_v2(Bucket=self.bucket, Prefix=path)
|
||||
response: ListObjectsV2OutputDict = self.client.list_objects_v2(
|
||||
Bucket=self.bucket, Prefix=path
|
||||
)
|
||||
contents = response.get('Contents')
|
||||
if not contents:
|
||||
return []
|
||||
paths = [obj['Key'] for obj in response['Contents']]
|
||||
paths = [obj['Key'] for obj in contents]
|
||||
for sub_path in paths:
|
||||
if sub_path == path:
|
||||
continue
|
||||
|
||||
@@ -7,21 +7,19 @@ from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class SettingsStore(ABC):
|
||||
"""
|
||||
Storage for ConversationInitData. May or may not support multiple users depending on the environment
|
||||
"""
|
||||
"""Storage for ConversationInitData. May or may not support multiple users depending on the environment."""
|
||||
|
||||
@abstractmethod
|
||||
async def load(self) -> Settings | None:
|
||||
"""Load session init data"""
|
||||
"""Load session init data."""
|
||||
|
||||
@abstractmethod
|
||||
async def store(self, settings: Settings) -> None:
|
||||
"""Store session init data"""
|
||||
"""Store session init data."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> SettingsStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
"""Get a store for the user represented by the token given."""
|
||||
|
||||
1434
poetry.lock
generated
1434
poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -18,7 +18,7 @@ datasets = "*"
|
||||
pandas = "*"
|
||||
litellm = "^1.60.0"
|
||||
google-generativeai = "*" # To use litellm with Gemini Pro API
|
||||
google-api-python-client = "*" # For Google Sheets API
|
||||
google-api-python-client = "^2.164.0" # For Google Sheets API
|
||||
google-auth-httplib2 = "*" # For Google Sheets authentication
|
||||
google-auth-oauthlib = "*" # For Google Sheets OAuth
|
||||
termcolor = "*"
|
||||
@@ -100,6 +100,11 @@ reportlab = "*"
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -129,6 +134,11 @@ ignore = ["D1"]
|
||||
convention = "google"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
@@ -147,6 +157,7 @@ browsergym = "0.10.2"
|
||||
browsergym-webarena = "0.10.2"
|
||||
browsergym-miniwob = "0.10.2"
|
||||
browsergym-visualwebarena = "0.10.2"
|
||||
boto3-stubs = {extras = ["s3"], version = "^1.37.19"}
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
|
||||
Reference in New Issue
Block a user