Fix mypy errors in core directory (#6901)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-02-24 10:00:57 -05:00
committed by GitHub
parent 753e3c4205
commit 8956f92f6a
5 changed files with 100 additions and 65 deletions

View File

@@ -25,14 +25,20 @@ def get_field_info(field: FieldInfo) -> dict[str, Any]:
# Note: this only works for UnionTypes with None as one of the types
if get_origin(field_type) is UnionType:
types = get_args(field_type)
non_none_arg = next((t for t in types if t is not type(None)), None)
non_none_arg = next(
(t for t in types if t is not None and t is not type(None)), None
)
if non_none_arg is not None:
field_type = non_none_arg
optional = True
# type name in a pretty format
type_name = (
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
str(field_type)
if field_type is None
else (
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
)
)
# default is always present

View File

@@ -10,17 +10,17 @@ class AgentError(Exception):
class AgentNoInstructionError(AgentError):
def __init__(self, message='Instruction must be provided'):
def __init__(self, message: str = 'Instruction must be provided') -> None:
super().__init__(message)
class AgentEventTypeError(AgentError):
def __init__(self, message='Event must be a dictionary'):
def __init__(self, message: str = 'Event must be a dictionary') -> None:
super().__init__(message)
class AgentAlreadyRegisteredError(AgentError):
def __init__(self, name=None):
def __init__(self, name: str | None = None) -> None:
if name is not None:
message = f"Agent class already registered under '{name}'"
else:
@@ -29,7 +29,7 @@ class AgentAlreadyRegisteredError(AgentError):
class AgentNotRegisteredError(AgentError):
def __init__(self, name=None):
def __init__(self, name: str | None = None) -> None:
if name is not None:
message = f"No agent class registered under '{name}'"
else:
@@ -38,7 +38,7 @@ class AgentNotRegisteredError(AgentError):
class AgentStuckInLoopError(AgentError):
def __init__(self, message='Agent got stuck in a loop'):
def __init__(self, message: str = 'Agent got stuck in a loop') -> None:
super().__init__(message)
@@ -48,7 +48,7 @@ class AgentStuckInLoopError(AgentError):
class TaskInvalidStateError(Exception):
def __init__(self, state=None):
def __init__(self, state: str | None = None) -> None:
if state is not None:
message = f'Invalid state {state}'
else:
@@ -64,45 +64,47 @@ class TaskInvalidStateError(Exception):
# This exception gets sent back to the LLM
# It might be malformed JSON
class LLMMalformedActionError(Exception):
def __init__(self, message='Malformed response'):
def __init__(self, message: str = 'Malformed response') -> None:
self.message = message
super().__init__(message)
def __str__(self):
def __str__(self) -> str:
return self.message
# This exception gets sent back to the LLM
# For some reason, the agent did not return an action
class LLMNoActionError(Exception):
def __init__(self, message='Agent must return an action'):
def __init__(self, message: str = 'Agent must return an action') -> None:
super().__init__(message)
# This exception gets sent back to the LLM
# The LLM output did not include an action, or the action was not the expected type
class LLMResponseError(Exception):
def __init__(self, message='Failed to retrieve action from LLM response'):
def __init__(
self, message: str = 'Failed to retrieve action from LLM response'
) -> None:
super().__init__(message)
class UserCancelledError(Exception):
def __init__(self, message='User cancelled the request'):
def __init__(self, message: str = 'User cancelled the request') -> None:
super().__init__(message)
class OperationCancelled(Exception):
"""Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
def __init__(self, message='Operation was cancelled'):
def __init__(self, message: str = 'Operation was cancelled') -> None:
super().__init__(message)
class LLMContextWindowExceedError(RuntimeError):
def __init__(
self,
message='Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error',
):
message: str = 'Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error',
) -> None:
super().__init__(message)
@@ -117,7 +119,7 @@ class FunctionCallConversionError(Exception):
This typically happens when there's a malformed message (e.g., missing <function=...> tags). But not due to LLM output.
"""
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(message)
@@ -127,14 +129,14 @@ class FunctionCallValidationError(Exception):
This typically happens when the LLM outputs unrecognized function call / parameter names / values.
"""
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(message)
class FunctionCallNotExistsError(Exception):
"""Exception raised when an LLM call a tool that is not registered."""
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(message)
@@ -191,15 +193,17 @@ class AgentRuntimeNotFoundError(AgentRuntimeUnavailableError):
class BrowserInitException(Exception):
def __init__(self, message='Failed to initialize browser environment'):
def __init__(
self, message: str = 'Failed to initialize browser environment'
) -> None:
super().__init__(message)
class BrowserUnavailableException(Exception):
def __init__(
self,
message='Browser environment is not available, please check if has been initialized',
):
message: str = 'Browser environment is not available, please check if has been initialized',
) -> None:
super().__init__(message)
@@ -217,5 +221,5 @@ class MicroAgentError(Exception):
class MicroAgentValidationError(MicroAgentError):
"""Raised when there's a validation error in microagent metadata."""
def __init__(self, message='Micro agent validation failed'):
def __init__(self, message: str = 'Micro agent validation failed') -> None:
super().__init__(message)

View File

@@ -74,10 +74,11 @@ LOG_COLORS: Mapping[str, ColorType] = {
class StackInfoFilter(logging.Filter):
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
if record.levelno >= logging.ERROR:
record.stack_info = True
record.exc_info = True
# LogRecord attributes are dynamically typed
setattr(record, 'stack_info', True)
setattr(record, 'exc_info', sys.exc_info())
return True
@@ -107,9 +108,9 @@ def strip_ansi(s: str) -> str:
class ColoredFormatter(logging.Formatter):
def format(self, record):
msg_type = record.__dict__.get('msg_type')
event_source = record.__dict__.get('event_source')
def format(self, record: logging.LogRecord) -> str:
msg_type = record.__dict__.get('msg_type', '')
event_source = record.__dict__.get('event_source', '')
if event_source:
new_msg_type = f'{event_source.upper()}_{msg_type}'
if new_msg_type in LOG_COLORS:
@@ -136,12 +137,13 @@ class ColoredFormatter(logging.Formatter):
return super().format(new_record)
def _fix_record(record: logging.LogRecord):
def _fix_record(record: logging.LogRecord) -> logging.LogRecord:
new_record = copy.copy(record)
# The formatter expects non boolean values, and will raise an exception if there is a boolean - so we fix these
if new_record.exc_info is True and not new_record.exc_text: # type: ignore
new_record.exc_info = sys.exc_info() # type: ignore
new_record.stack_info = None # type: ignore
# LogRecord attributes are dynamically typed
if getattr(new_record, 'exc_info', None) is True:
setattr(new_record, 'exc_info', sys.exc_info())
setattr(new_record, 'stack_info', None)
return new_record
@@ -158,32 +160,32 @@ class RollingLogger:
log_lines: list[str]
all_lines: str
def __init__(self, max_lines=10, char_limit=80):
def __init__(self, max_lines: int = 10, char_limit: int = 80) -> None:
self.max_lines = max_lines
self.char_limit = char_limit
self.log_lines = [''] * self.max_lines
self.all_lines = ''
def is_enabled(self):
def is_enabled(self) -> bool:
return DEBUG and sys.stdout.isatty()
def start(self, message=''):
def start(self, message: str = '') -> None:
if message:
print(message)
self._write('\n' * self.max_lines)
self._flush()
def add_line(self, line):
def add_line(self, line: str) -> None:
self.log_lines.pop(0)
self.log_lines.append(line[: self.char_limit])
self.print_lines()
self.all_lines += line + '\n'
def write_immediately(self, line):
def write_immediately(self, line: str) -> None:
self._write(line)
self._flush()
def print_lines(self):
def print_lines(self) -> None:
"""Display the last n log_lines in the console (not for file logging).
This will create the effect of a rolling display in the console.
@@ -192,31 +194,31 @@ class RollingLogger:
for line in self.log_lines:
self.replace_current_line(line)
def move_back(self, amount=-1):
def move_back(self, amount: int = -1) -> None:
r"""'\033[F' moves the cursor up one line."""
if amount == -1:
amount = self.max_lines
self._write('\033[F' * (self.max_lines))
self._flush()
def replace_current_line(self, line=''):
def replace_current_line(self, line: str = '') -> None:
r"""'\033[2K\r' clears the line and moves the cursor to the beginning of the line."""
self._write('\033[2K' + line + '\n')
self._flush()
def _write(self, line):
def _write(self, line: str) -> None:
if not self.is_enabled():
return
sys.stdout.write(line)
def _flush(self):
def _flush(self) -> None:
if not self.is_enabled():
return
sys.stdout.flush()
class SensitiveDataFilter(logging.Filter):
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
# Gather sensitive values which should not ever appear in the logs.
sensitive_values = []
for key, value in os.environ.items():
@@ -263,7 +265,9 @@ class SensitiveDataFilter(logging.Filter):
return True
def get_console_handler(log_level: int = logging.INFO, extra_info: str | None = None):
def get_console_handler(
log_level: int = logging.INFO, extra_info: str | None = None
) -> logging.StreamHandler:
"""Returns a console handler for logging."""
console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
@@ -274,7 +278,9 @@ def get_console_handler(log_level: int = logging.INFO, extra_info: str | None =
return console_handler
def get_file_handler(log_dir: str, log_level: int = logging.INFO):
def get_file_handler(
log_dir: str, log_level: int = logging.INFO
) -> logging.FileHandler:
"""Returns a file handler for logging."""
os.makedirs(log_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y-%m-%d')
@@ -348,7 +354,13 @@ logging.getLogger('LiteLLM Proxy').disabled = True
class LlmFileHandler(logging.FileHandler):
"""LLM prompt and response logging."""
def __init__(self, filename, mode='a', encoding='utf-8', delay=False):
def __init__(
self,
filename: str,
mode: str = 'a',
encoding: str = 'utf-8',
delay: bool = False,
) -> None:
"""Initializes an instance of LlmFileHandler.
Args:
@@ -379,7 +391,7 @@ class LlmFileHandler(logging.FileHandler):
self.baseFilename = os.path.join(self.log_directory, filename)
super().__init__(self.baseFilename, mode, encoding, delay)
def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
"""Emits a log record.
Args:
@@ -394,7 +406,7 @@ class LlmFileHandler(logging.FileHandler):
self.message_counter += 1
def _get_llm_file_handler(name: str, log_level: int):
def _get_llm_file_handler(name: str, log_level: int) -> LlmFileHandler:
# The 'delay' parameter, when set to True, postpones the opening of the log file
# until the first log message is emitted.
llm_file_handler = LlmFileHandler(name, delay=True)
@@ -403,7 +415,7 @@ def _get_llm_file_handler(name: str, log_level: int):
return llm_file_handler
def _setup_llm_logger(name: str, log_level: int):
def _setup_llm_logger(name: str, log_level: int) -> logging.Logger:
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(log_level)

View File

@@ -15,7 +15,9 @@ class Content(BaseModel):
cache_prompt: bool = False
@model_serializer
def serialize_model(self):
def serialize_model(
self,
) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]:
raise NotImplementedError('Subclasses should implement this method.')
@@ -24,7 +26,7 @@ class TextContent(Content):
text: str
@model_serializer
def serialize_model(self):
def serialize_model(self) -> dict[str, str | dict[str, str]]:
data: dict[str, str | dict[str, str]] = {
'type': self.type,
'text': self.text,
@@ -39,7 +41,7 @@ class ImageContent(Content):
image_urls: list[str]
@model_serializer
def serialize_model(self):
def serialize_model(self) -> list[dict[str, str | dict[str, str]]]:
images: list[dict[str, str | dict[str, str]]] = []
for url in self.image_urls:
images.append({'type': self.type, 'image_url': {'url': url}})
@@ -101,15 +103,22 @@ class Message(BaseModel):
# See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
if self.role == 'tool' and item.cache_prompt:
role_tool_with_prompt_caching = True
if isinstance(d, dict):
d.pop('cache_control')
elif isinstance(d, list):
for d_item in d:
d_item.pop('cache_control')
if isinstance(item, TextContent):
d.pop('cache_control', None)
elif isinstance(item, ImageContent):
# ImageContent.model_dump() always returns a list
# We know d is a list of dicts for ImageContent
if hasattr(d, '__iter__'):
for d_item in d:
if hasattr(d_item, 'pop'):
d_item.pop('cache_control', None)
if isinstance(item, TextContent):
content.append(d)
elif isinstance(item, ImageContent) and self.vision_enabled:
content.extend(d)
# ImageContent.model_dump() always returns a list
# We know d is a list for ImageContent
content.extend([d] if isinstance(d, dict) else d)
message_dict: dict = {'content': content, 'role': self.role}

View File

@@ -160,7 +160,7 @@ def get_action_message(
)
llm_response: ModelResponse = tool_metadata.model_response
assistant_msg = llm_response.choices[0].message
assistant_msg = getattr(llm_response.choices[0], 'message')
# Add the LLM message (assistant) that initiated the tool calls
# (overwrites any previous message with the same response_id)
@@ -168,7 +168,7 @@ def get_action_message(
f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}'
)
pending_tool_call_action_messages[llm_response.id] = Message(
role=assistant_msg.role,
role=getattr(assistant_msg, 'role', 'assistant'),
# tool call content SHOULD BE a string
content=[TextContent(text=assistant_msg.content or '')]
if assistant_msg.content is not None
@@ -185,7 +185,7 @@ def get_action_message(
tool_metadata = action.tool_call_metadata
if tool_metadata is not None:
# take the response message from the tool call
assistant_msg = tool_metadata.model_response.choices[0].message
assistant_msg = getattr(tool_metadata.model_response.choices[0], 'message')
content = assistant_msg.content or ''
# save content if any, to thought
@@ -197,9 +197,11 @@ def get_action_message(
# remove the tool call metadata
action.tool_call_metadata = None
if role not in ('user', 'system', 'assistant', 'tool'):
raise ValueError(f'Invalid role: {role}')
return [
Message(
role=role,
role=role, # type: ignore[arg-type]
content=[TextContent(text=action.thought)],
)
]
@@ -208,9 +210,11 @@ def get_action_message(
content = [TextContent(text=action.content or '')]
if vision_is_active and action.image_urls:
content.append(ImageContent(image_urls=action.image_urls))
if role not in ('user', 'system', 'assistant', 'tool'):
raise ValueError(f'Invalid role: {role}')
return [
Message(
role=role,
role=role, # type: ignore[arg-type]
content=content,
)
]
@@ -218,7 +222,7 @@ def get_action_message(
content = [TextContent(text=f'User executed the command:\n{action.command}')]
return [
Message(
role='user',
role='user', # Always user for CmdRunAction
content=content,
)
]