More Type Safety (#10848)

This commit is contained in:
Tim O'Farrell 2025-09-05 11:34:43 -06:00 committed by GitHub
parent 3366ad9de7
commit aca568cfbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 11 deletions

View File

@@ -7,6 +7,7 @@ warn_unreachable = True
warn_redundant_casts = True
no_implicit_optional = True
strict_optional = True
disable_error_code = type-abstract
# Exclude third-party runtime directory from type checking
exclude = (third_party/|enterprise/)

View File

@@ -31,13 +31,15 @@ def load_experiment_config(conversation_id: str) -> ExperimentConfig | None:
class ExperimentManager:
@staticmethod
def run_conversation_variant_test(
    user_id: str | None,
    conversation_id: str,
    conversation_settings: ConversationInitData,
) -> ConversationInitData:
    """Hook for applying conversation-level experiment variants.

    This default implementation is a no-op and returns the provided
    settings unchanged. `user_id` may be None (e.g. unauthenticated
    sessions), hence the `str | None` annotation.

    Args:
        user_id: Identifier of the requesting user, or None.
        conversation_id: Identifier of the conversation under test.
        conversation_settings: Settings to (potentially) modify.

    Returns:
        The conversation settings, unmodified by this base implementation.
    """
    return conversation_settings
@staticmethod
def run_config_variant_test(
user_id: str, conversation_id: str, config: OpenHandsConfig
user_id: str | None, conversation_id: str, config: OpenHandsConfig
) -> OpenHandsConfig:
exp_config = load_experiment_config(conversation_id)
if exp_config and exp_config.config:

View File

@@ -45,7 +45,6 @@ from openhands.server.services.conversation_service import (
setup_init_conversation_settings,
)
from openhands.server.shared import (
ConversationManagerImpl,
ConversationStoreImpl,
config,
conversation_manager,
@@ -397,7 +396,7 @@ async def get_prompt(
)
prompt_template = generate_prompt_template(stringified_events)
prompt = generate_prompt(llm_config, prompt_template, conversation_id)
prompt = await generate_prompt(llm_config, prompt_template, conversation_id)
return JSONResponse(
{
@@ -413,7 +412,7 @@ def generate_prompt_template(events: str) -> str:
return template.render(events=events)
def generate_prompt(
async def generate_prompt(
llm_config: LLMConfig, prompt_template: str, conversation_id: str
) -> str:
messages = [
@@ -427,7 +426,7 @@ def generate_prompt(
},
]
raw_prompt = ConversationManagerImpl.request_llm_completion(
raw_prompt = await conversation_manager.request_llm_completion(
'remember_prompt', conversation_id, llm_config, messages
)
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)

View File

@@ -31,6 +31,14 @@ def import_from(qual_name: str):
@lru_cache()
def _get_impl(cls: type[T], impl_name: str | None) -> type[T]:
    """Cached resolver backing ``get_impl``.

    When ``impl_name`` is None the base class itself is the implementation;
    otherwise the dotted path is imported and validated to be ``cls`` or a
    subclass of it before being returned.
    """
    if impl_name is None:
        # No override configured — fall back to the base class.
        return cls
    candidate = import_from(impl_name)
    # The resolved class must be the base itself or one of its subclasses.
    assert cls == candidate or issubclass(candidate, cls)
    return candidate
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
"""Import and validate a named implementation of a base class.
@@ -62,8 +70,4 @@ def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
The implementation is cached to avoid repeated imports of the same class.
"""
if impl_name is None:
return cls
impl_class = import_from(impl_name)
assert cls == impl_class or issubclass(impl_class, cls)
return impl_class
return _get_impl(cls, impl_name) # type: ignore