mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add proper typing to cli directory (#8374)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
5ad11e73b8
commit
60d9b519e0
@ -273,7 +273,6 @@ async def run_session(
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Runs the agent in CLI mode."""
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
@ -15,6 +15,7 @@ from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
|
||||
from prompt_toolkit.input import create_input
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.key_binding.key_processor import KeyPressEvent
|
||||
from prompt_toolkit.keys import Keys
|
||||
from prompt_toolkit.layout.containers import HSplit, Window
|
||||
from prompt_toolkit.layout.controls import FormattedTextControl
|
||||
@ -70,7 +71,7 @@ print_lock = threading.Lock()
|
||||
|
||||
|
||||
class UsageMetrics:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.metrics: Metrics = Metrics()
|
||||
self.session_init_time: float = time.time()
|
||||
|
||||
@ -78,7 +79,7 @@ class UsageMetrics:
|
||||
class CustomDiffLexer(Lexer):
|
||||
"""Custom lexer for the specific diff format."""
|
||||
|
||||
def lex_document(self, document) -> StyleAndTextTuples:
|
||||
def lex_document(self, document: Document) -> StyleAndTextTuples:
|
||||
lines = document.lines
|
||||
|
||||
def get_line(lineno: int) -> StyleAndTextTuples:
|
||||
@ -427,7 +428,7 @@ def display_agent_state_change_message(agent_state: str) -> None:
|
||||
class CommandCompleter(Completer):
|
||||
"""Custom completer for commands."""
|
||||
|
||||
def __init__(self, agent_state: str):
|
||||
def __init__(self, agent_state: str) -> None:
|
||||
super().__init__()
|
||||
self.agent_state = agent_state
|
||||
|
||||
@ -450,7 +451,7 @@ class CommandCompleter(Completer):
|
||||
)
|
||||
|
||||
|
||||
def create_prompt_session() -> PromptSession:
|
||||
def create_prompt_session() -> PromptSession[str]:
|
||||
return PromptSession(style=DEFAULT_STYLE)
|
||||
|
||||
|
||||
@ -465,7 +466,7 @@ async def read_prompt_input(agent_state: str, multiline: bool = False) -> str:
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add('c-d')
|
||||
def _(event) -> None:
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
event.current_buffer.validate_and_handle()
|
||||
|
||||
with patch_stdout():
|
||||
@ -538,11 +539,10 @@ async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) ->
|
||||
def cli_confirm(
|
||||
question: str = 'Are you sure?', choices: list[str] | None = None
|
||||
) -> int:
|
||||
"""
|
||||
Display a confirmation prompt with the given question and choices.
|
||||
"""Display a confirmation prompt with the given question and choices.
|
||||
|
||||
Returns the index of the selected choice.
|
||||
"""
|
||||
|
||||
if choices is None:
|
||||
choices = ['Yes', 'No']
|
||||
selected = [0] # Using list to allow modification in closure
|
||||
@ -561,15 +561,15 @@ def cli_confirm(
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add('up')
|
||||
def _(event) -> None:
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
selected[0] = (selected[0] - 1) % len(choices)
|
||||
|
||||
@kb.add('down')
|
||||
def _(event) -> None:
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
selected[0] = (selected[0] + 1) % len(choices)
|
||||
|
||||
@kb.add('enter')
|
||||
def _(event) -> None:
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
event.app.exit(result=selected[0])
|
||||
|
||||
style = Style.from_dict({'selected': COLOR_GOLD, 'unselected': ''})
|
||||
@ -601,7 +601,7 @@ def kb_cancel() -> KeyBindings:
|
||||
bindings = KeyBindings()
|
||||
|
||||
@bindings.add('escape')
|
||||
def _(event) -> None:
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
event.app.exit(exception=UserCancelledError, style='class:aborting')
|
||||
|
||||
return bindings
|
||||
|
||||
@ -78,8 +78,7 @@ class ModelInfo(BaseModel):
|
||||
|
||||
|
||||
def extract_model_and_provider(model: str) -> ModelInfo:
|
||||
"""
|
||||
Extract provider and model information from a model identifier.
|
||||
"""Extract provider and model information from a model identifier.
|
||||
|
||||
Args:
|
||||
model: The model identifier string
|
||||
@ -114,8 +113,7 @@ def extract_model_and_provider(model: str) -> ModelInfo:
|
||||
def organize_models_and_providers(
|
||||
models: list[str],
|
||||
) -> dict[str, 'ProviderInfo']:
|
||||
"""
|
||||
Organize a list of model identifiers by provider.
|
||||
"""Organize a list of model identifiers by provider.
|
||||
|
||||
Args:
|
||||
models: List of model identifiers
|
||||
@ -188,7 +186,7 @@ class ProviderInfo(BaseModel):
|
||||
return self.models
|
||||
raise KeyError(f'ProviderInfo has no key {key}')
|
||||
|
||||
def get(self, key: str, default=None) -> str | list[str] | None:
|
||||
def get(self, key: str, default: None = None) -> str | list[str] | None:
|
||||
"""Dictionary-like get method with default value."""
|
||||
try:
|
||||
return self[key]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user