Feat named imports (#5413)

This commit is contained in:
tofarr 2024-12-05 12:10:52 -07:00 committed by GitHub
parent 3d853f7db3
commit c3ddb26e43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 97 additions and 26 deletions

View File

@ -67,6 +67,9 @@ class AppConfig:
modal_api_token_secret: str = ''
disable_color: bool = False
jwt_secret: str = uuid.uuid4().hex
attach_session_middleware_class: str = (
'openhands.server.middleware.AttachSessionMiddleware'
)
debug: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False

View File

@ -150,11 +150,14 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
) -> ErrorObservation | None:
linter = DefaultLinter()
# Copy the original file to a temporary file (with the same ext) and lint it
with tempfile.NamedTemporaryFile(
suffix=suffix, mode='w+', encoding='utf-8'
) as original_file_copy, tempfile.NamedTemporaryFile(
suffix=suffix, mode='w+', encoding='utf-8'
) as updated_file_copy:
with (
tempfile.NamedTemporaryFile(
suffix=suffix, mode='w+', encoding='utf-8'
) as original_file_copy,
tempfile.NamedTemporaryFile(
suffix=suffix, mode='w+', encoding='utf-8'
) as updated_file_copy,
):
# Lint the original file
original_file_copy.write(old_content)
original_file_copy.flush()

View File

@ -21,6 +21,8 @@ from openhands.server.routes.feedback import app as feedback_api_router
from openhands.server.routes.files import app as files_api_router
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
from openhands.server.shared import config
from openhands.utils.import_utils import get_impl
app = FastAPI()
app.add_middleware(
@ -48,9 +50,16 @@ app.include_router(conversation_api_router)
app.include_router(security_api_router)
app.include_router(feedback_api_router)
app.middleware('http')(AttachSessionMiddleware(app, target_router=files_api_router))
app.middleware('http')(
AttachSessionMiddleware(app, target_router=conversation_api_router)
AttachSessionMiddlewareImpl = get_impl(
AttachSessionMiddleware, config.attach_session_middleware_class
)
app.middleware('http')(AttachSessionMiddlewareImpl(app, target_router=files_api_router))
app.middleware('http')(
AttachSessionMiddlewareImpl(app, target_router=conversation_api_router)
)
app.middleware('http')(
AttachSessionMiddlewareImpl(app, target_router=security_api_router)
)
app.middleware('http')(
AttachSessionMiddlewareImpl(app, target_router=feedback_api_router)
)
app.middleware('http')(AttachSessionMiddleware(app, target_router=security_api_router))
app.middleware('http')(AttachSessionMiddleware(app, target_router=feedback_api_router))

View File

@ -0,0 +1,22 @@
import importlib
from typing import Type, TypeVar
T = TypeVar('T')
def import_from(qual_name: str):
"""Import the value from the qualified name given"""
parts = qual_name.split('.')
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
result = getattr(module, parts[-1])
return result
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
"""Import a named implementation of the specified class"""
if impl_name is None:
return cls
impl_class = import_from(impl_name)
assert cls == impl_class or issubclass(impl_class, cls)
return impl_class

View File

@ -389,16 +389,23 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
handler_instance.get_instruction.return_value = ('Test instruction', [])
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.initialize_runtime',
mock_initialize_runtime,
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch(
'openhands.resolver.resolve_issue.complete_runtime', mock_complete_runtime
), patch('openhands.resolver.resolve_issue.logger'):
with (
patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
),
patch(
'openhands.resolver.resolve_issue.initialize_runtime',
mock_initialize_runtime,
),
patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
),
patch(
'openhands.resolver.resolve_issue.complete_runtime',
mock_complete_runtime,
),
patch('openhands.resolver.resolve_issue.logger'),
):
# Call the function
result = await process_issue(
issue,

View File

@ -0,0 +1,24 @@
from abc import abstractmethod
from dataclasses import dataclass
from openhands.utils.import_utils import get_impl
class Shape:
@abstractmethod
def get_area(self):
"""Get the area of this shape"""
@dataclass
class Square(Shape):
length: float
def get_area(self):
return self.length**2
def test_get_impl():
ShapeImpl = get_impl(Shape, f'{Shape.__module__}.Square')
shape = ShapeImpl(5)
assert shape.get_area() == 25

View File

@ -16,8 +16,9 @@ class MockStaticFiles:
# Patch necessary components before importing from listen
with patch('openhands.server.session.SessionManager', MockSessionManager), patch(
'fastapi.staticfiles.StaticFiles', MockStaticFiles
with (
patch('openhands.server.session.SessionManager', MockSessionManager),
patch('fastapi.staticfiles.StaticFiles', MockStaticFiles),
):
from openhands.server.file_config import (
is_extension_allowed,
@ -53,8 +54,9 @@ def test_load_file_upload_config_invalid_max_size():
def test_is_extension_allowed():
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']
with (
patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True),
patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']),
):
assert is_extension_allowed('file.txt')
assert is_extension_allowed('file.pdf')
@ -71,8 +73,9 @@ def test_is_extension_allowed_no_restrictions():
def test_is_extension_allowed_wildcard():
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']
with (
patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True),
patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']),
):
assert is_extension_allowed('file.txt')
assert is_extension_allowed('file.pdf')