mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
More Type Safety (#10848)
This commit is contained in:
parent
3366ad9de7
commit
aca568cfbe
@ -7,6 +7,7 @@ warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
no_implicit_optional = True
|
||||
strict_optional = True
|
||||
disable_error_code = type-abstract
|
||||
|
||||
# Exclude third-party runtime directory from type checking
|
||||
exclude = (third_party/|enterprise/)
|
||||
|
||||
@ -31,13 +31,15 @@ def load_experiment_config(conversation_id: str) -> ExperimentConfig | None:
|
||||
class ExperimentManager:
|
||||
@staticmethod
|
||||
def run_conversation_variant_test(
|
||||
user_id: str, conversation_id: str, conversation_settings: ConversationInitData
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
return conversation_settings
|
||||
|
||||
@staticmethod
|
||||
def run_config_variant_test(
|
||||
user_id: str, conversation_id: str, config: OpenHandsConfig
|
||||
user_id: str | None, conversation_id: str, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
exp_config = load_experiment_config(conversation_id)
|
||||
if exp_config and exp_config.config:
|
||||
|
||||
@ -45,7 +45,6 @@ from openhands.server.services.conversation_service import (
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import (
|
||||
ConversationManagerImpl,
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
@ -397,7 +396,7 @@ async def get_prompt(
|
||||
)
|
||||
|
||||
prompt_template = generate_prompt_template(stringified_events)
|
||||
prompt = generate_prompt(llm_config, prompt_template, conversation_id)
|
||||
prompt = await generate_prompt(llm_config, prompt_template, conversation_id)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
@ -413,7 +412,7 @@ def generate_prompt_template(events: str) -> str:
|
||||
return template.render(events=events)
|
||||
|
||||
|
||||
def generate_prompt(
|
||||
async def generate_prompt(
|
||||
llm_config: LLMConfig, prompt_template: str, conversation_id: str
|
||||
) -> str:
|
||||
messages = [
|
||||
@ -427,7 +426,7 @@ def generate_prompt(
|
||||
},
|
||||
]
|
||||
|
||||
raw_prompt = ConversationManagerImpl.request_llm_completion(
|
||||
raw_prompt = await conversation_manager.request_llm_completion(
|
||||
'remember_prompt', conversation_id, llm_config, messages
|
||||
)
|
||||
prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
|
||||
|
||||
@ -31,6 +31,14 @@ def import_from(qual_name: str):
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_impl(cls: type[T], impl_name: str | None) -> type[T]:
|
||||
if impl_name is None:
|
||||
return cls
|
||||
impl_class = import_from(impl_name)
|
||||
assert cls == impl_class or issubclass(impl_class, cls)
|
||||
return impl_class
|
||||
|
||||
|
||||
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
|
||||
"""Import and validate a named implementation of a base class.
|
||||
|
||||
@ -62,8 +70,4 @@ def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
|
||||
|
||||
The implementation is cached to avoid repeated imports of the same class.
|
||||
"""
|
||||
if impl_name is None:
|
||||
return cls
|
||||
impl_class = import_from(impl_name)
|
||||
assert cls == impl_class or issubclass(impl_class, cls)
|
||||
return impl_class
|
||||
return _get_impl(cls, impl_name) # type: ignore
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user