Add proper typing to cli directory (#8374)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-05-13 09:55:44 -04:00 committed by GitHub
parent 5ad11e73b8
commit 60d9b519e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 18 deletions

View File

@ -273,7 +273,6 @@ async def run_session(
async def main(loop: asyncio.AbstractEventLoop) -> None:
"""Runs the agent in CLI mode."""
args = parse_arguments()
logger.setLevel(logging.WARNING)

View File

@ -15,6 +15,7 @@ from prompt_toolkit.document import Document
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
from prompt_toolkit.input import create_input
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.key_binding.key_processor import KeyPressEvent
from prompt_toolkit.keys import Keys
from prompt_toolkit.layout.containers import HSplit, Window
from prompt_toolkit.layout.controls import FormattedTextControl
@ -70,7 +71,7 @@ print_lock = threading.Lock()
class UsageMetrics:
def __init__(self):
def __init__(self) -> None:
self.metrics: Metrics = Metrics()
self.session_init_time: float = time.time()
@ -78,7 +79,7 @@ class UsageMetrics:
class CustomDiffLexer(Lexer):
"""Custom lexer for the specific diff format."""
def lex_document(self, document) -> StyleAndTextTuples:
def lex_document(self, document: Document) -> StyleAndTextTuples:
lines = document.lines
def get_line(lineno: int) -> StyleAndTextTuples:
@ -427,7 +428,7 @@ def display_agent_state_change_message(agent_state: str) -> None:
class CommandCompleter(Completer):
"""Custom completer for commands."""
def __init__(self, agent_state: str):
def __init__(self, agent_state: str) -> None:
super().__init__()
self.agent_state = agent_state
@ -450,7 +451,7 @@ class CommandCompleter(Completer):
)
def create_prompt_session() -> PromptSession:
def create_prompt_session() -> PromptSession[str]:
return PromptSession(style=DEFAULT_STYLE)
@ -465,7 +466,7 @@ async def read_prompt_input(agent_state: str, multiline: bool = False) -> str:
kb = KeyBindings()
@kb.add('c-d')
def _(event) -> None:
def _(event: KeyPressEvent) -> None:
event.current_buffer.validate_and_handle()
with patch_stdout():
@ -538,11 +539,10 @@ async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) ->
def cli_confirm(
question: str = 'Are you sure?', choices: list[str] | None = None
) -> int:
"""
Display a confirmation prompt with the given question and choices.
"""Display a confirmation prompt with the given question and choices.
Returns the index of the selected choice.
"""
if choices is None:
choices = ['Yes', 'No']
selected = [0] # Using list to allow modification in closure
@ -561,15 +561,15 @@ def cli_confirm(
kb = KeyBindings()
@kb.add('up')
def _(event) -> None:
def _(event: KeyPressEvent) -> None:
selected[0] = (selected[0] - 1) % len(choices)
@kb.add('down')
def _(event) -> None:
def _(event: KeyPressEvent) -> None:
selected[0] = (selected[0] + 1) % len(choices)
@kb.add('enter')
def _(event) -> None:
def _(event: KeyPressEvent) -> None:
event.app.exit(result=selected[0])
style = Style.from_dict({'selected': COLOR_GOLD, 'unselected': ''})
@ -601,7 +601,7 @@ def kb_cancel() -> KeyBindings:
bindings = KeyBindings()
@bindings.add('escape')
def _(event) -> None:
def _(event: KeyPressEvent) -> None:
event.app.exit(exception=UserCancelledError, style='class:aborting')
return bindings

View File

@ -78,8 +78,7 @@ class ModelInfo(BaseModel):
def extract_model_and_provider(model: str) -> ModelInfo:
"""
Extract provider and model information from a model identifier.
"""Extract provider and model information from a model identifier.
Args:
model: The model identifier string
@ -114,8 +113,7 @@ def extract_model_and_provider(model: str) -> ModelInfo:
def organize_models_and_providers(
models: list[str],
) -> dict[str, 'ProviderInfo']:
"""
Organize a list of model identifiers by provider.
"""Organize a list of model identifiers by provider.
Args:
models: List of model identifiers
@ -188,7 +186,7 @@ class ProviderInfo(BaseModel):
return self.models
raise KeyError(f'ProviderInfo has no key {key}')
def get(self, key: str, default=None) -> str | list[str] | None:
def get(self, key: str, default: None = None) -> str | list[str] | None:
"""Dictionary-like get method with default value."""
try:
return self[key]