From c3ddb26e4302a8f41510e8c060a2c426263625f7 Mon Sep 17 00:00:00 2001 From: tofarr Date: Thu, 5 Dec 2024 12:10:52 -0700 Subject: [PATCH] Feat named imports (#5413) --- openhands/core/config/app_config.py | 3 +++ openhands/runtime/utils/edit.py | 13 +++++++---- openhands/server/app.py | 19 +++++++++++---- openhands/utils/import_utils.py | 22 ++++++++++++++++++ tests/unit/resolver/test_resolve_issues.py | 27 ++++++++++++++-------- tests/unit/test_import_utils.py | 24 +++++++++++++++++++ tests/unit/test_listen.py | 15 +++++++----- 7 files changed, 97 insertions(+), 26 deletions(-) create mode 100644 openhands/utils/import_utils.py create mode 100644 tests/unit/test_import_utils.py diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index a10df018df..9e883860cb 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -67,6 +67,9 @@ class AppConfig: modal_api_token_secret: str = '' disable_color: bool = False jwt_secret: str = uuid.uuid4().hex + attach_session_middleware_class: str = ( + 'openhands.server.middleware.AttachSessionMiddleware' + ) debug: bool = False file_uploads_max_file_size_mb: int = 0 file_uploads_restrict_file_types: bool = False diff --git a/openhands/runtime/utils/edit.py b/openhands/runtime/utils/edit.py index bcb876f865..d95dacb100 100644 --- a/openhands/runtime/utils/edit.py +++ b/openhands/runtime/utils/edit.py @@ -150,11 +150,14 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): ) -> ErrorObservation | None: linter = DefaultLinter() # Copy the original file to a temporary file (with the same ext) and lint it - with tempfile.NamedTemporaryFile( - suffix=suffix, mode='w+', encoding='utf-8' - ) as original_file_copy, tempfile.NamedTemporaryFile( - suffix=suffix, mode='w+', encoding='utf-8' - ) as updated_file_copy: + with ( + tempfile.NamedTemporaryFile( + suffix=suffix, mode='w+', encoding='utf-8' + ) as original_file_copy, + tempfile.NamedTemporaryFile( + suffix=suffix, mode='w+', encoding='utf-8' + ) as updated_file_copy, + ): # Lint the original file original_file_copy.write(old_content) original_file_copy.flush() diff --git a/openhands/server/app.py b/openhands/server/app.py index 33f9766fe6..b168795978 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -21,6 +21,8 @@ from openhands.server.routes.feedback import app as feedback_api_router from openhands.server.routes.files import app as files_api_router from openhands.server.routes.public import app as public_api_router from openhands.server.routes.security import app as security_api_router +from openhands.server.shared import config +from openhands.utils.import_utils import get_impl app = FastAPI() app.add_middleware( @@ -48,9 +50,16 @@ app.include_router(conversation_api_router) app.include_router(security_api_router) app.include_router(feedback_api_router) -app.middleware('http')(AttachSessionMiddleware(app, target_router=files_api_router)) -app.middleware('http')( - AttachSessionMiddleware(app, target_router=conversation_api_router) +AttachSessionMiddlewareImpl = get_impl( + AttachSessionMiddleware, config.attach_session_middleware_class +) +app.middleware('http')(AttachSessionMiddlewareImpl(app, target_router=files_api_router)) +app.middleware('http')( + AttachSessionMiddlewareImpl(app, target_router=conversation_api_router) +) +app.middleware('http')( + AttachSessionMiddlewareImpl(app, target_router=security_api_router) +) +app.middleware('http')( + AttachSessionMiddlewareImpl(app, target_router=feedback_api_router) ) -app.middleware('http')(AttachSessionMiddleware(app, target_router=security_api_router)) -app.middleware('http')(AttachSessionMiddleware(app, target_router=feedback_api_router)) diff --git a/openhands/utils/import_utils.py b/openhands/utils/import_utils.py new file mode 100644 index 0000000000..1a14c119de --- /dev/null +++ b/openhands/utils/import_utils.py @@ -0,0 +1,22 @@ +import importlib +from typing import Type, TypeVar + +T = TypeVar('T') + + +def import_from(qual_name: str): + """Import the value from the qualified name given""" + parts = qual_name.split('.') + module_name = '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + result = getattr(module, parts[-1]) + return result + + +def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]: + """Import a named implementation of the specified class""" + if impl_name is None: + return cls + impl_class = import_from(impl_name) + assert cls == impl_class or issubclass(impl_class, cls) + return impl_class diff --git a/tests/unit/resolver/test_resolve_issues.py b/tests/unit/resolver/test_resolve_issues.py index 8d54adb876..9dd43eff82 100644 --- a/tests/unit/resolver/test_resolve_issues.py +++ b/tests/unit/resolver/test_resolve_issues.py @@ -389,16 +389,23 @@ async def test_process_issue(mock_output_dir, mock_prompt_template): handler_instance.get_instruction.return_value = ('Test instruction', []) handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' - with patch( - 'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime - ), patch( - 'openhands.resolver.resolve_issue.initialize_runtime', - mock_initialize_runtime, - ), patch( - 'openhands.resolver.resolve_issue.run_controller', mock_run_controller - ), patch( - 'openhands.resolver.resolve_issue.complete_runtime', mock_complete_runtime - ), patch('openhands.resolver.resolve_issue.logger'): + with ( + patch( + 'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime + ), + patch( + 'openhands.resolver.resolve_issue.initialize_runtime', + mock_initialize_runtime, + ), + patch( + 'openhands.resolver.resolve_issue.run_controller', mock_run_controller + ), + patch( + 'openhands.resolver.resolve_issue.complete_runtime', + mock_complete_runtime, + ), + patch('openhands.resolver.resolve_issue.logger'), + ): # Call the function result = await process_issue( issue, diff --git a/tests/unit/test_import_utils.py b/tests/unit/test_import_utils.py new file mode 100644 index 0000000000..876c37d27b --- /dev/null +++ b/tests/unit/test_import_utils.py @@ -0,0 +1,24 @@ +from abc import abstractmethod +from dataclasses import dataclass + +from openhands.utils.import_utils import get_impl + + +class Shape: + @abstractmethod + def get_area(self): + """Get the area of this shape""" + + +@dataclass +class Square(Shape): + length: float + + def get_area(self): + return self.length**2 + + +def test_get_impl(): + ShapeImpl = get_impl(Shape, f'{Shape.__module__}.Square') + shape = ShapeImpl(5) + assert shape.get_area() == 25 diff --git a/tests/unit/test_listen.py b/tests/unit/test_listen.py index f19be8aedb..c39c656e0a 100644 --- a/tests/unit/test_listen.py +++ b/tests/unit/test_listen.py @@ -16,8 +16,9 @@ class MockStaticFiles: # Patch necessary components before importing from listen -with patch('openhands.server.session.SessionManager', MockSessionManager), patch( - 'fastapi.staticfiles.StaticFiles', MockStaticFiles +with ( + patch('openhands.server.session.SessionManager', MockSessionManager), + patch('fastapi.staticfiles.StaticFiles', MockStaticFiles), ): from openhands.server.file_config import ( is_extension_allowed, @@ -53,8 +54,9 @@ def test_load_file_upload_config_invalid_max_size(): def test_is_extension_allowed(): - with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch( - 'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf'] + with ( + patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), + patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']), ): assert is_extension_allowed('file.txt') assert is_extension_allowed('file.pdf') @@ -71,8 +73,9 @@ def test_is_extension_allowed_no_restrictions(): def test_is_extension_allowed_wildcard(): - with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch( - 'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*'] + with ( + patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), + patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']), ): assert is_extension_allowed('file.txt') assert is_extension_allowed('file.pdf')