mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Feat named imports (#5413)
This commit is contained in:
parent
3d853f7db3
commit
c3ddb26e43
@ -67,6 +67,9 @@ class AppConfig:
|
||||
modal_api_token_secret: str = ''
|
||||
disable_color: bool = False
|
||||
jwt_secret: str = uuid.uuid4().hex
|
||||
attach_session_middleware_class: str = (
|
||||
'openhands.server.middleware.AttachSessionMiddleware'
|
||||
)
|
||||
debug: bool = False
|
||||
file_uploads_max_file_size_mb: int = 0
|
||||
file_uploads_restrict_file_types: bool = False
|
||||
|
||||
@ -150,11 +150,14 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
) -> ErrorObservation | None:
|
||||
linter = DefaultLinter()
|
||||
# Copy the original file to a temporary file (with the same ext) and lint it
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, mode='w+', encoding='utf-8'
|
||||
) as original_file_copy, tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, mode='w+', encoding='utf-8'
|
||||
) as updated_file_copy:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, mode='w+', encoding='utf-8'
|
||||
) as original_file_copy,
|
||||
tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, mode='w+', encoding='utf-8'
|
||||
) as updated_file_copy,
|
||||
):
|
||||
# Lint the original file
|
||||
original_file_copy.write(old_content)
|
||||
original_file_copy.flush()
|
||||
|
||||
@ -21,6 +21,8 @@ from openhands.server.routes.feedback import app as feedback_api_router
|
||||
from openhands.server.routes.files import app as files_api_router
|
||||
from openhands.server.routes.public import app as public_api_router
|
||||
from openhands.server.routes.security import app as security_api_router
|
||||
from openhands.server.shared import config
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
@ -48,9 +50,16 @@ app.include_router(conversation_api_router)
|
||||
app.include_router(security_api_router)
|
||||
app.include_router(feedback_api_router)
|
||||
|
||||
app.middleware('http')(AttachSessionMiddleware(app, target_router=files_api_router))
|
||||
app.middleware('http')(
|
||||
AttachSessionMiddleware(app, target_router=conversation_api_router)
|
||||
AttachSessionMiddlewareImpl = get_impl(
|
||||
AttachSessionMiddleware, config.attach_session_middleware_class
|
||||
)
|
||||
app.middleware('http')(AttachSessionMiddlewareImpl(app, target_router=files_api_router))
|
||||
app.middleware('http')(
|
||||
AttachSessionMiddlewareImpl(app, target_router=conversation_api_router)
|
||||
)
|
||||
app.middleware('http')(
|
||||
AttachSessionMiddlewareImpl(app, target_router=security_api_router)
|
||||
)
|
||||
app.middleware('http')(
|
||||
AttachSessionMiddlewareImpl(app, target_router=feedback_api_router)
|
||||
)
|
||||
app.middleware('http')(AttachSessionMiddleware(app, target_router=security_api_router))
|
||||
app.middleware('http')(AttachSessionMiddleware(app, target_router=feedback_api_router))
|
||||
|
||||
22
openhands/utils/import_utils.py
Normal file
22
openhands/utils/import_utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
import importlib
|
||||
from typing import Type, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def import_from(qual_name: str):
|
||||
"""Import the value from the qualified name given"""
|
||||
parts = qual_name.split('.')
|
||||
module_name = '.'.join(parts[:-1])
|
||||
module = importlib.import_module(module_name)
|
||||
result = getattr(module, parts[-1])
|
||||
return result
|
||||
|
||||
|
||||
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
|
||||
"""Import a named implementation of the specified class"""
|
||||
if impl_name is None:
|
||||
return cls
|
||||
impl_class = import_from(impl_name)
|
||||
assert cls == impl_class or issubclass(impl_class, cls)
|
||||
return impl_class
|
||||
@ -389,16 +389,23 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
handler_instance.get_instruction.return_value = ('Test instruction', [])
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
|
||||
with patch(
|
||||
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
|
||||
), patch(
|
||||
'openhands.resolver.resolve_issue.initialize_runtime',
|
||||
mock_initialize_runtime,
|
||||
), patch(
|
||||
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
|
||||
), patch(
|
||||
'openhands.resolver.resolve_issue.complete_runtime', mock_complete_runtime
|
||||
), patch('openhands.resolver.resolve_issue.logger'):
|
||||
with (
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.initialize_runtime',
|
||||
mock_initialize_runtime,
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.complete_runtime',
|
||||
mock_complete_runtime,
|
||||
),
|
||||
patch('openhands.resolver.resolve_issue.logger'),
|
||||
):
|
||||
# Call the function
|
||||
result = await process_issue(
|
||||
issue,
|
||||
|
||||
24
tests/unit/test_import_utils.py
Normal file
24
tests/unit/test_import_utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class Shape:
|
||||
@abstractmethod
|
||||
def get_area(self):
|
||||
"""Get the area of this shape"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Square(Shape):
|
||||
length: float
|
||||
|
||||
def get_area(self):
|
||||
return self.length**2
|
||||
|
||||
|
||||
def test_get_impl():
|
||||
ShapeImpl = get_impl(Shape, f'{Shape.__module__}.Square')
|
||||
shape = ShapeImpl(5)
|
||||
assert shape.get_area() == 25
|
||||
@ -16,8 +16,9 @@ class MockStaticFiles:
|
||||
|
||||
|
||||
# Patch necessary components before importing from listen
|
||||
with patch('openhands.server.session.SessionManager', MockSessionManager), patch(
|
||||
'fastapi.staticfiles.StaticFiles', MockStaticFiles
|
||||
with (
|
||||
patch('openhands.server.session.SessionManager', MockSessionManager),
|
||||
patch('fastapi.staticfiles.StaticFiles', MockStaticFiles),
|
||||
):
|
||||
from openhands.server.file_config import (
|
||||
is_extension_allowed,
|
||||
@ -53,8 +54,9 @@ def test_load_file_upload_config_invalid_max_size():
|
||||
|
||||
|
||||
def test_is_extension_allowed():
|
||||
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
|
||||
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']
|
||||
with (
|
||||
patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True),
|
||||
patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.txt', '.pdf']),
|
||||
):
|
||||
assert is_extension_allowed('file.txt')
|
||||
assert is_extension_allowed('file.pdf')
|
||||
@ -71,8 +73,9 @@ def test_is_extension_allowed_no_restrictions():
|
||||
|
||||
|
||||
def test_is_extension_allowed_wildcard():
|
||||
with patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True), patch(
|
||||
'openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']
|
||||
with (
|
||||
patch('openhands.server.file_config.RESTRICT_FILE_TYPES', True),
|
||||
patch('openhands.server.file_config.ALLOWED_EXTENSIONS', ['.*']),
|
||||
):
|
||||
assert is_extension_allowed('file.txt')
|
||||
assert is_extension_allowed('file.pdf')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user