From 7d356cad477b8d764a7eb2596272d51a2c8acb13 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Thu, 8 May 2025 21:25:02 -0400 Subject: [PATCH] Add type annotations to CLI directory (#8291) Co-authored-by: openhands --- openhands/cli/commands.py | 9 ++-- openhands/cli/main.py | 6 +-- openhands/cli/settings.py | 44 +++++++-------- openhands/cli/tui.py | 60 +++++++++++---------- openhands/cli/utils.py | 110 ++++++++++++++++++++++++++++++-------- 5 files changed, 151 insertions(+), 78 deletions(-) diff --git a/openhands/cli/commands.py b/openhands/cli/commands.py index e0f52b703a..e9f6a4de25 100644 --- a/openhands/cli/commands.py +++ b/openhands/cli/commands.py @@ -100,7 +100,7 @@ def handle_exit_command( return close_repl -def handle_help_command(): +def handle_help_command() -> None: display_help() @@ -135,7 +135,7 @@ async def handle_init_command( return close_repl, reload_microagents -def handle_status_command(usage_metrics: UsageMetrics, sid: str): +def handle_status_command(usage_metrics: UsageMetrics, sid: str) -> None: display_status(usage_metrics, sid) @@ -168,7 +168,7 @@ def handle_new_command( async def handle_settings_command( config: AppConfig, settings_store: FileSettingsStore, -): +) -> None: display_settings(config) modify_settings = cli_confirm( '\nWhich settings would you like to modify?', @@ -213,6 +213,7 @@ async def init_repository(current_dir: str) -> bool: if repo_file_path.exists(): try: + # Path.exists() ensures repo_file_path is not None, so we can safely pass it to read_file content = await asyncio.get_event_loop().run_in_executor( None, read_file, repo_file_path ) @@ -263,7 +264,7 @@ async def init_repository(current_dir: str) -> bool: return init_repo -def check_folder_security_agreement(config: AppConfig, current_dir): +def check_folder_security_agreement(config: AppConfig, current_dir: str) -> bool: # Directories trusted by user for the CLI to use as workspace # Config from ~/.openhands/config.toml overrides the app config diff --git a/openhands/cli/main.py b/openhands/cli/main.py index bc51dbbc0c..d9cbf79256 100644 --- a/openhands/cli/main.py +++ b/openhands/cli/main.py @@ -68,7 +68,7 @@ async def cleanup_session( agent: Agent, runtime: Runtime, controller: AgentController, -): +) -> None: """Clean up all resources from the current session.""" try: # Cancel all running tasks except the current one @@ -126,7 +126,7 @@ async def run_session( usage_metrics = UsageMetrics() - async def prompt_for_next_task(agent_state: str): + async def prompt_for_next_task(agent_state: str) -> None: nonlocal reload_microagents, new_session_requested while True: next_message = await read_prompt_input( @@ -271,7 +271,7 @@ async def run_session( return new_session_requested -async def main(loop: asyncio.AbstractEventLoop): +async def main(loop: asyncio.AbstractEventLoop) -> None: """Runs the agent in CLI mode.""" args = parse_arguments() diff --git a/openhands/cli/settings.py b/openhands/cli/settings.py index eeee3cc74f..68f7fdb165 100644 --- a/openhands/cli/settings.py +++ b/openhands/cli/settings.py @@ -29,7 +29,7 @@ from openhands.storage.settings.file_settings_store import FileSettingsStore from openhands.utils.llm import get_supported_llm_models -def display_settings(config: AppConfig): +def display_settings(config: AppConfig) -> None: llm_config = config.get_llm_config() advanced_llm_settings = True if llm_config.base_url else False @@ -108,8 +108,8 @@ async def get_validated_input( prompt_text: str, completer=None, validator=None, - error_message='Input cannot be empty', -): + error_message: str = 'Input cannot be empty', +) -> str: session.completer = completer value = None @@ -146,7 +146,7 @@ def save_settings_confirmation() -> bool: async def modify_llm_settings_basic( config: AppConfig, settings_store: FileSettingsStore -): +) -> None: model_list = get_supported_llm_models(config) organized_models = organize_models_and_providers(model_list) @@ -171,20 +171,24 @@ async def modify_llm_settings_basic( error_message='Invalid provider selected', ) - model_list = organized_models[provider]['models'] + provider_models = organized_models[provider]['models'] if provider == 'openai': - model_list = [m for m in model_list if m not in VERIFIED_OPENAI_MODELS] - model_list = VERIFIED_OPENAI_MODELS + model_list + provider_models = [ + m for m in provider_models if m not in VERIFIED_OPENAI_MODELS + ] + provider_models = VERIFIED_OPENAI_MODELS + provider_models if provider == 'anthropic': - model_list = [m for m in model_list if m not in VERIFIED_ANTHROPIC_MODELS] - model_list = VERIFIED_ANTHROPIC_MODELS + model_list + provider_models = [ + m for m in provider_models if m not in VERIFIED_ANTHROPIC_MODELS + ] + provider_models = VERIFIED_ANTHROPIC_MODELS + provider_models - model_completer = FuzzyWordCompleter(model_list) + model_completer = FuzzyWordCompleter(provider_models) model = await get_validated_input( session, '(Step 2/3) Select LLM Model (TAB for options, CTRL-c to cancel): ', completer=model_completer, - validator=lambda x: x in organized_models[provider]['models'], + validator=lambda x: x in provider_models, error_message=f'Invalid model selected for provider {provider}', ) @@ -201,10 +205,8 @@ async def modify_llm_settings_basic( ): return # Return on exception - # TODO: check for empty string inputs? - # Handle case where a prompt might return None unexpectedly - if provider is None or model is None or api_key is None: - return + # The try-except block above ensures we either have valid inputs or we've already returned + # No need to check for None values here save_settings = save_settings_confirmation() @@ -212,7 +214,7 @@ async def modify_llm_settings_basic( return llm_config = config.get_llm_config() - llm_config.model = provider + organized_models[provider]['separator'] + model + llm_config.model = f'{provider}{organized_models[provider]["separator"]}{model}' llm_config.api_key = SecretStr(api_key) llm_config.base_url = None config.set_llm_config(llm_config) @@ -232,7 +234,7 @@ async def modify_llm_settings_basic( if not settings: settings = Settings() - settings.llm_model = provider + organized_models[provider]['separator'] + model + settings.llm_model = f'{provider}{organized_models[provider]["separator"]}{model}' settings.llm_api_key = SecretStr(api_key) settings.llm_base_url = None settings.agent = OH_DEFAULT_AGENT @@ -244,7 +246,7 @@ async def modify_llm_settings_basic( async def modify_llm_settings_advanced( config: AppConfig, settings_store: FileSettingsStore -): +) -> None: session = PromptSession(key_bindings=kb_cancel()) custom_model = None @@ -304,10 +306,8 @@ async def modify_llm_settings_advanced( ): return # Return on exception - # TODO: check for empty string inputs? - # Handle case where a prompt might return None unexpectedly - if custom_model is None or base_url is None or api_key is None or agent is None: - return + # The try-except block above ensures we either have valid inputs or we've already returned + # No need to check for None values here save_settings = save_settings_confirmation() diff --git a/openhands/cli/tui.py b/openhands/cli/tui.py index ae8981610a..e84676f83f 100644 --- a/openhands/cli/tui.py +++ b/openhands/cli/tui.py @@ -6,10 +6,12 @@ import asyncio import sys import threading import time +from typing import Generator from prompt_toolkit import PromptSession, print_formatted_text from prompt_toolkit.application import Application -from prompt_toolkit.completion import Completer, Completion +from prompt_toolkit.completion import CompleteEvent, Completer, Completion +from prompt_toolkit.document import Document from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples from prompt_toolkit.input import create_input from prompt_toolkit.key_binding import KeyBindings @@ -96,7 +98,7 @@ class CustomDiffLexer(Lexer): # CLI initialization and startup display functions -def display_runtime_initialization_message(runtime: str): +def display_runtime_initialization_message(runtime: str) -> None: print_formatted_text('') if runtime == 'local': print_formatted_text(HTML('⚙️ Starting local runtime...')) @@ -105,7 +107,7 @@ def display_runtime_initialization_message(runtime: str): print_formatted_text('') -def display_initialization_animation(text, is_loaded: asyncio.Event): +def display_initialization_animation(text: str, is_loaded: asyncio.Event) -> None: ANIMATION_FRAMES = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] i = 0 @@ -122,7 +124,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event): sys.stdout.flush() -def display_banner(session_id: str): +def display_banner(session_id: str) -> None: print_formatted_text( HTML(r""" ___ _ _ _ @@ -142,7 +144,7 @@ def display_banner(session_id: str): print_formatted_text('') -def display_welcome_message(): +def display_welcome_message() -> None: print_formatted_text( HTML("Let's start building!\n"), style=DEFAULT_STYLE ) @@ -152,7 +154,7 @@ def display_welcome_message(): ) -def display_initial_user_prompt(prompt: str): +def display_initial_user_prompt(prompt: str) -> None: print_formatted_text( FormattedText( [ @@ -187,14 +189,14 @@ def display_event(event: Event, config: AppConfig) -> None: display_agent_state_change_message(event.agent_state) -def display_message(message: str): +def display_message(message: str) -> None: message = message.strip() if message: print_formatted_text(f'\n{message}') -def display_command(event: CmdRunAction): +def display_command(event: CmdRunAction) -> None: if event.confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION: container = Frame( TextArea( @@ -210,7 +212,7 @@ def display_command(event: CmdRunAction): print_container(container) -def display_command_output(output: str): +def display_command_output(output: str) -> None: lines = output.split('\n') formatted_lines = [] for line in lines: @@ -238,7 +240,7 @@ def display_command_output(output: str): print_container(container) -def display_file_edit(event: FileEditObservation): +def display_file_edit(event: FileEditObservation) -> None: container = Frame( TextArea( text=event.visualize_diff(n_context_lines=4), @@ -253,7 +255,7 @@ def display_file_edit(event: FileEditObservation): print_container(container) -def display_file_read(event: FileReadObservation): +def display_file_read(event: FileReadObservation) -> None: content = event.content.replace('\t', ' ') container = Frame( TextArea( @@ -270,7 +272,7 @@ def display_file_read(event: FileReadObservation): # Interactive command output display functions -def display_help(): +def display_help() -> None: # Version header and introduction print_formatted_text( HTML( @@ -314,7 +316,7 @@ def display_help(): ) -def display_usage_metrics(usage_metrics: UsageMetrics): +def display_usage_metrics(usage_metrics: UsageMetrics) -> None: cost_str = f'${usage_metrics.metrics.accumulated_cost:.6f}' input_tokens_str = ( f'{usage_metrics.metrics.accumulated_token_usage.prompt_tokens:,}' @@ -375,7 +377,7 @@ def get_session_duration(session_init_time: float) -> str: return f'{int(hours)}h {int(minutes)}m {int(seconds)}s' -def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str): +def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str) -> None: duration_str = get_session_duration(usage_metrics.session_init_time) print_formatted_text(HTML('Closing current conversation...')) @@ -388,7 +390,7 @@ def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str): print_formatted_text('') -def display_status(usage_metrics: UsageMetrics, session_id: str): +def display_status(usage_metrics: UsageMetrics, session_id: str) -> None: duration_str = get_session_duration(usage_metrics.session_init_time) print_formatted_text('') @@ -398,14 +400,14 @@ def display_status(usage_metrics: UsageMetrics, session_id: str): display_usage_metrics(usage_metrics) -def display_agent_running_message(): +def display_agent_running_message() -> None: print_formatted_text('') print_formatted_text( HTML('Agent running... (Press Ctrl-P to pause)') ) -def display_agent_state_change_message(agent_state: str): +def display_agent_state_change_message(agent_state: str) -> None: if agent_state == AgentState.PAUSED: print_formatted_text('') print_formatted_text( @@ -429,7 +431,9 @@ class CommandCompleter(Completer): super().__init__() self.agent_state = agent_state - def get_completions(self, document, complete_event): + def get_completions( + self, document: Document, complete_event: CompleteEvent + ) -> Generator[Completion, None, None]: text = document.text_before_cursor.lstrip() if text.startswith('/'): available_commands = dict(COMMANDS) @@ -446,11 +450,11 @@ class CommandCompleter(Completer): ) -def create_prompt_session(): +def create_prompt_session() -> PromptSession: return PromptSession(style=DEFAULT_STYLE) -async def read_prompt_input(agent_state: str, multiline=False): +async def read_prompt_input(agent_state: str, multiline: bool = False) -> str: try: prompt_session = create_prompt_session() prompt_session.completer = ( @@ -461,7 +465,7 @@ async def read_prompt_input(agent_state: str, multiline=False): kb = KeyBindings() @kb.add('c-d') - def _(event): + def _(event) -> None: event.current_buffer.validate_and_handle() with patch_stdout(): @@ -511,7 +515,7 @@ async def read_confirmation_input() -> str: async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None: input = create_input() - def keys_ready(): + def keys_ready() -> None: for key_press in input.read_keys(): if ( key_press.key == Keys.ControlP @@ -543,7 +547,7 @@ def cli_confirm( choices = ['Yes', 'No'] selected = [0] # Using list to allow modification in closure - def get_choice_text(): + def get_choice_text() -> list: return [ ('class:question', f'{question}\n\n'), ] + [ @@ -557,15 +561,15 @@ def cli_confirm( kb = KeyBindings() @kb.add('up') - def _(event): + def _(event) -> None: selected[0] = (selected[0] - 1) % len(choices) @kb.add('down') - def _(event): + def _(event) -> None: selected[0] = (selected[0] + 1) % len(choices) @kb.add('enter') - def _(event): + def _(event) -> None: event.app.exit(result=selected[0]) style = Style.from_dict({'selected': COLOR_GOLD, 'unselected': ''}) @@ -592,12 +596,12 @@ def cli_confirm( return app.run(in_thread=True) -def kb_cancel(): +def kb_cancel() -> KeyBindings: """Custom key bindings to handle ESC as a user cancellation.""" bindings = KeyBindings() @bindings.add('escape') - def _(event): + def _(event) -> None: event.app.exit(exception=UserCancelledError, style='class:aborting') return bindings diff --git a/openhands/cli/utils.py b/openhands/cli/utils.py index ac806c95b1..276f31d614 100644 --- a/openhands/cli/utils.py +++ b/openhands/cli/utils.py @@ -1,6 +1,7 @@ from pathlib import Path import toml +from pydantic import BaseModel, Field from openhands.cli.tui import ( UsageMetrics, @@ -24,7 +25,7 @@ def get_local_config_trusted_dirs() -> list[str]: return [] -def add_local_config_trusted_dir(folder_path: str): +def add_local_config_trusted_dir(folder_path: str) -> None: config = _DEFAULT_CONFIG if _LOCAL_CONFIG_FILE_PATH.exists(): try: @@ -47,7 +48,7 @@ def add_local_config_trusted_dir(folder_path: str): toml.dump(config, f) -def update_usage_metrics(event: Event, usage_metrics: UsageMetrics): +def update_usage_metrics(event: Event, usage_metrics: UsageMetrics) -> None: if not hasattr(event, 'llm_metrics'): return @@ -58,7 +59,34 @@ def update_usage_metrics(event: Event, usage_metrics: UsageMetrics): usage_metrics.metrics = llm_metrics -def extract_model_and_provider(model): +class ModelInfo(BaseModel): + """Information about a model and its provider.""" + + provider: str = Field(description='The provider of the model') + model: str = Field(description='The model identifier') + separator: str = Field(description='The separator used in the model identifier') + + def __getitem__(self, key: str) -> str: + """Allow dictionary-like access to fields.""" + if key == 'provider': + return self.provider + elif key == 'model': + return self.model + elif key == 'separator': + return self.separator + raise KeyError(f'ModelInfo has no key {key}') + + +def extract_model_and_provider(model: str) -> ModelInfo: + """ + Extract provider and model information from a model identifier. + + Args: + model: The model identifier string + + Returns: + A ModelInfo object containing provider, model, and separator information + """ separator = '/' split = model.split(separator) @@ -72,25 +100,36 @@ def extract_model_and_provider(model): if len(split) == 1: # no "/" or "." separator found if split[0] in VERIFIED_OPENAI_MODELS: - return {'provider': 'openai', 'model': split[0], 'separator': '/'} + return ModelInfo(provider='openai', model=split[0], separator='/') if split[0] in VERIFIED_ANTHROPIC_MODELS: - return {'provider': 'anthropic', 'model': split[0], 'separator': '/'} + return ModelInfo(provider='anthropic', model=split[0], separator='/') # return as model only - return {'provider': '', 'model': model, 'separator': ''} + return ModelInfo(provider='', model=model, separator='') provider = split[0] model_id = separator.join(split[1:]) - return {'provider': provider, 'model': model_id, 'separator': separator} + return ModelInfo(provider=provider, model=model_id, separator=separator) -def organize_models_and_providers(models): - result = {} +def organize_models_and_providers( + models: list[str], +) -> dict[str, 'ProviderInfo']: + """ + Organize a list of model identifiers by provider. + + Args: + models: List of model identifiers + + Returns: + A mapping of providers to their information and models + """ + result_dict: dict[str, ProviderInfo] = {} for model in models: extracted = extract_model_and_provider(model) - separator = extracted['separator'] - provider = extracted['provider'] - model_id = extracted['model'] + separator = extracted.separator + provider = extracted.provider + model_id = extracted.model # Ignore "anthropic" providers with a separator of "." # These are outdated and incompatible providers. @@ -98,12 +137,12 @@ def organize_models_and_providers(models): continue key = provider or 'other' - if key not in result: - result[key] = {'separator': separator, 'models': []} + if key not in result_dict: + result_dict[key] = ProviderInfo(separator=separator, models=[]) - result[key]['models'].append(model_id) + result_dict[key].models.append(model_id) - return result + return result_dict VERIFIED_PROVIDERS = ['openai', 'azure', 'anthropic', 'deepseek'] @@ -133,19 +172,48 @@ VERIFIED_ANTHROPIC_MODELS = [ ] -def is_number(char): +class ProviderInfo(BaseModel): + """Information about a provider and its models.""" + + separator: str = Field(description='The separator used in model identifiers') + models: list[str] = Field( + default_factory=list, description='List of model identifiers' + ) + + def __getitem__(self, key: str) -> str | list[str]: + """Allow dictionary-like access to fields.""" + if key == 'separator': + return self.separator + elif key == 'models': + return self.models + raise KeyError(f'ProviderInfo has no key {key}') + + def get(self, key: str, default=None) -> str | list[str] | None: + """Dictionary-like get method with default value.""" + try: + return self[key] + except KeyError: + return default + + +def is_number(char: str) -> bool: return char.isdigit() -def split_is_actually_version(split): - return len(split) > 1 and split[1] and split[1][0] and is_number(split[1][0]) +def split_is_actually_version(split: list[str]) -> bool: + return ( + len(split) > 1 + and bool(split[1]) + and bool(split[1][0]) + and is_number(split[1][0]) + ) -def read_file(file_path): +def read_file(file_path: str | Path) -> str: with open(file_path, 'r') as f: return f.read() -def write_to_file(file_path, content): +def write_to_file(file_path: str | Path, content: str) -> None: with open(file_path, 'w') as f: f.write(content)