mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Fix mypy errors in core directory (#6901)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -25,14 +25,20 @@ def get_field_info(field: FieldInfo) -> dict[str, Any]:
|
||||
# Note: this only works for UnionTypes with None as one of the types
|
||||
if get_origin(field_type) is UnionType:
|
||||
types = get_args(field_type)
|
||||
non_none_arg = next((t for t in types if t is not type(None)), None)
|
||||
non_none_arg = next(
|
||||
(t for t in types if t is not None and t is not type(None)), None
|
||||
)
|
||||
if non_none_arg is not None:
|
||||
field_type = non_none_arg
|
||||
optional = True
|
||||
|
||||
# type name in a pretty format
|
||||
type_name = (
|
||||
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
|
||||
str(field_type)
|
||||
if field_type is None
|
||||
else (
|
||||
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
|
||||
)
|
||||
)
|
||||
|
||||
# default is always present
|
||||
|
||||
@@ -10,17 +10,17 @@ class AgentError(Exception):
|
||||
|
||||
|
||||
class AgentNoInstructionError(AgentError):
|
||||
def __init__(self, message='Instruction must be provided'):
|
||||
def __init__(self, message: str = 'Instruction must be provided') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentEventTypeError(AgentError):
|
||||
def __init__(self, message='Event must be a dictionary'):
|
||||
def __init__(self, message: str = 'Event must be a dictionary') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentAlreadyRegisteredError(AgentError):
|
||||
def __init__(self, name=None):
|
||||
def __init__(self, name: str | None = None) -> None:
|
||||
if name is not None:
|
||||
message = f"Agent class already registered under '{name}'"
|
||||
else:
|
||||
@@ -29,7 +29,7 @@ class AgentAlreadyRegisteredError(AgentError):
|
||||
|
||||
|
||||
class AgentNotRegisteredError(AgentError):
|
||||
def __init__(self, name=None):
|
||||
def __init__(self, name: str | None = None) -> None:
|
||||
if name is not None:
|
||||
message = f"No agent class registered under '{name}'"
|
||||
else:
|
||||
@@ -38,7 +38,7 @@ class AgentNotRegisteredError(AgentError):
|
||||
|
||||
|
||||
class AgentStuckInLoopError(AgentError):
|
||||
def __init__(self, message='Agent got stuck in a loop'):
|
||||
def __init__(self, message: str = 'Agent got stuck in a loop') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class AgentStuckInLoopError(AgentError):
|
||||
|
||||
|
||||
class TaskInvalidStateError(Exception):
|
||||
def __init__(self, state=None):
|
||||
def __init__(self, state: str | None = None) -> None:
|
||||
if state is not None:
|
||||
message = f'Invalid state {state}'
|
||||
else:
|
||||
@@ -64,45 +64,47 @@ class TaskInvalidStateError(Exception):
|
||||
# This exception gets sent back to the LLM
|
||||
# It might be malformed JSON
|
||||
class LLMMalformedActionError(Exception):
|
||||
def __init__(self, message='Malformed response'):
|
||||
def __init__(self, message: str = 'Malformed response') -> None:
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.message
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# For some reason, the agent did not return an action
|
||||
class LLMNoActionError(Exception):
|
||||
def __init__(self, message='Agent must return an action'):
|
||||
def __init__(self, message: str = 'Agent must return an action') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# The LLM output did not include an action, or the action was not the expected type
|
||||
class LLMResponseError(Exception):
|
||||
def __init__(self, message='Failed to retrieve action from LLM response'):
|
||||
def __init__(
|
||||
self, message: str = 'Failed to retrieve action from LLM response'
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserCancelledError(Exception):
|
||||
def __init__(self, message='User cancelled the request'):
|
||||
def __init__(self, message: str = 'User cancelled the request') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OperationCancelled(Exception):
|
||||
"""Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
|
||||
|
||||
def __init__(self, message='Operation was cancelled'):
|
||||
def __init__(self, message: str = 'Operation was cancelled') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class LLMContextWindowExceedError(RuntimeError):
|
||||
def __init__(
|
||||
self,
|
||||
message='Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error',
|
||||
):
|
||||
message: str = 'Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error',
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -117,7 +119,7 @@ class FunctionCallConversionError(Exception):
|
||||
This typically happens when there's a malformed message (e.g., missing <function=...> tags). But not due to LLM output.
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -127,14 +129,14 @@ class FunctionCallValidationError(Exception):
|
||||
This typically happens when the LLM outputs unrecognized function call / parameter names / values.
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class FunctionCallNotExistsError(Exception):
|
||||
"""Exception raised when an LLM call a tool that is not registered."""
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -191,15 +193,17 @@ class AgentRuntimeNotFoundError(AgentRuntimeUnavailableError):
|
||||
|
||||
|
||||
class BrowserInitException(Exception):
|
||||
def __init__(self, message='Failed to initialize browser environment'):
|
||||
def __init__(
|
||||
self, message: str = 'Failed to initialize browser environment'
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserUnavailableException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message='Browser environment is not available, please check if has been initialized',
|
||||
):
|
||||
message: str = 'Browser environment is not available, please check if has been initialized',
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -217,5 +221,5 @@ class MicroAgentError(Exception):
|
||||
class MicroAgentValidationError(MicroAgentError):
|
||||
"""Raised when there's a validation error in microagent metadata."""
|
||||
|
||||
def __init__(self, message='Micro agent validation failed'):
|
||||
def __init__(self, message: str = 'Micro agent validation failed') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
@@ -74,10 +74,11 @@ LOG_COLORS: Mapping[str, ColorType] = {
|
||||
|
||||
|
||||
class StackInfoFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.levelno >= logging.ERROR:
|
||||
record.stack_info = True
|
||||
record.exc_info = True
|
||||
# LogRecord attributes are dynamically typed
|
||||
setattr(record, 'stack_info', True)
|
||||
setattr(record, 'exc_info', sys.exc_info())
|
||||
return True
|
||||
|
||||
|
||||
@@ -107,9 +108,9 @@ def strip_ansi(s: str) -> str:
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
msg_type = record.__dict__.get('msg_type')
|
||||
event_source = record.__dict__.get('event_source')
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
msg_type = record.__dict__.get('msg_type', '')
|
||||
event_source = record.__dict__.get('event_source', '')
|
||||
if event_source:
|
||||
new_msg_type = f'{event_source.upper()}_{msg_type}'
|
||||
if new_msg_type in LOG_COLORS:
|
||||
@@ -136,12 +137,13 @@ class ColoredFormatter(logging.Formatter):
|
||||
return super().format(new_record)
|
||||
|
||||
|
||||
def _fix_record(record: logging.LogRecord):
|
||||
def _fix_record(record: logging.LogRecord) -> logging.LogRecord:
|
||||
new_record = copy.copy(record)
|
||||
# The formatter expects non boolean values, and will raise an exception if there is a boolean - so we fix these
|
||||
if new_record.exc_info is True and not new_record.exc_text: # type: ignore
|
||||
new_record.exc_info = sys.exc_info() # type: ignore
|
||||
new_record.stack_info = None # type: ignore
|
||||
# LogRecord attributes are dynamically typed
|
||||
if getattr(new_record, 'exc_info', None) is True:
|
||||
setattr(new_record, 'exc_info', sys.exc_info())
|
||||
setattr(new_record, 'stack_info', None)
|
||||
return new_record
|
||||
|
||||
|
||||
@@ -158,32 +160,32 @@ class RollingLogger:
|
||||
log_lines: list[str]
|
||||
all_lines: str
|
||||
|
||||
def __init__(self, max_lines=10, char_limit=80):
|
||||
def __init__(self, max_lines: int = 10, char_limit: int = 80) -> None:
|
||||
self.max_lines = max_lines
|
||||
self.char_limit = char_limit
|
||||
self.log_lines = [''] * self.max_lines
|
||||
self.all_lines = ''
|
||||
|
||||
def is_enabled(self):
|
||||
def is_enabled(self) -> bool:
|
||||
return DEBUG and sys.stdout.isatty()
|
||||
|
||||
def start(self, message=''):
|
||||
def start(self, message: str = '') -> None:
|
||||
if message:
|
||||
print(message)
|
||||
self._write('\n' * self.max_lines)
|
||||
self._flush()
|
||||
|
||||
def add_line(self, line):
|
||||
def add_line(self, line: str) -> None:
|
||||
self.log_lines.pop(0)
|
||||
self.log_lines.append(line[: self.char_limit])
|
||||
self.print_lines()
|
||||
self.all_lines += line + '\n'
|
||||
|
||||
def write_immediately(self, line):
|
||||
def write_immediately(self, line: str) -> None:
|
||||
self._write(line)
|
||||
self._flush()
|
||||
|
||||
def print_lines(self):
|
||||
def print_lines(self) -> None:
|
||||
"""Display the last n log_lines in the console (not for file logging).
|
||||
|
||||
This will create the effect of a rolling display in the console.
|
||||
@@ -192,31 +194,31 @@ class RollingLogger:
|
||||
for line in self.log_lines:
|
||||
self.replace_current_line(line)
|
||||
|
||||
def move_back(self, amount=-1):
|
||||
def move_back(self, amount: int = -1) -> None:
|
||||
r"""'\033[F' moves the cursor up one line."""
|
||||
if amount == -1:
|
||||
amount = self.max_lines
|
||||
self._write('\033[F' * (self.max_lines))
|
||||
self._flush()
|
||||
|
||||
def replace_current_line(self, line=''):
|
||||
def replace_current_line(self, line: str = '') -> None:
|
||||
r"""'\033[2K\r' clears the line and moves the cursor to the beginning of the line."""
|
||||
self._write('\033[2K' + line + '\n')
|
||||
self._flush()
|
||||
|
||||
def _write(self, line):
|
||||
def _write(self, line: str) -> None:
|
||||
if not self.is_enabled():
|
||||
return
|
||||
sys.stdout.write(line)
|
||||
|
||||
def _flush(self):
|
||||
def _flush(self) -> None:
|
||||
if not self.is_enabled():
|
||||
return
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
class SensitiveDataFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Gather sensitive values which should not ever appear in the logs.
|
||||
sensitive_values = []
|
||||
for key, value in os.environ.items():
|
||||
@@ -263,7 +265,9 @@ class SensitiveDataFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
def get_console_handler(log_level: int = logging.INFO, extra_info: str | None = None):
|
||||
def get_console_handler(
|
||||
log_level: int = logging.INFO, extra_info: str | None = None
|
||||
) -> logging.StreamHandler:
|
||||
"""Returns a console handler for logging."""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(log_level)
|
||||
@@ -274,7 +278,9 @@ def get_console_handler(log_level: int = logging.INFO, extra_info: str | None =
|
||||
return console_handler
|
||||
|
||||
|
||||
def get_file_handler(log_dir: str, log_level: int = logging.INFO):
|
||||
def get_file_handler(
|
||||
log_dir: str, log_level: int = logging.INFO
|
||||
) -> logging.FileHandler:
|
||||
"""Returns a file handler for logging."""
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d')
|
||||
@@ -348,7 +354,13 @@ logging.getLogger('LiteLLM Proxy').disabled = True
|
||||
class LlmFileHandler(logging.FileHandler):
|
||||
"""LLM prompt and response logging."""
|
||||
|
||||
def __init__(self, filename, mode='a', encoding='utf-8', delay=False):
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
mode: str = 'a',
|
||||
encoding: str = 'utf-8',
|
||||
delay: bool = False,
|
||||
) -> None:
|
||||
"""Initializes an instance of LlmFileHandler.
|
||||
|
||||
Args:
|
||||
@@ -379,7 +391,7 @@ class LlmFileHandler(logging.FileHandler):
|
||||
self.baseFilename = os.path.join(self.log_directory, filename)
|
||||
super().__init__(self.baseFilename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""Emits a log record.
|
||||
|
||||
Args:
|
||||
@@ -394,7 +406,7 @@ class LlmFileHandler(logging.FileHandler):
|
||||
self.message_counter += 1
|
||||
|
||||
|
||||
def _get_llm_file_handler(name: str, log_level: int):
|
||||
def _get_llm_file_handler(name: str, log_level: int) -> LlmFileHandler:
|
||||
# The 'delay' parameter, when set to True, postpones the opening of the log file
|
||||
# until the first log message is emitted.
|
||||
llm_file_handler = LlmFileHandler(name, delay=True)
|
||||
@@ -403,7 +415,7 @@ def _get_llm_file_handler(name: str, log_level: int):
|
||||
return llm_file_handler
|
||||
|
||||
|
||||
def _setup_llm_logger(name: str, log_level: int):
|
||||
def _setup_llm_logger(name: str, log_level: int) -> logging.Logger:
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(log_level)
|
||||
|
||||
@@ -15,7 +15,9 @@ class Content(BaseModel):
|
||||
cache_prompt: bool = False
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
def serialize_model(
|
||||
self,
|
||||
) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]:
|
||||
raise NotImplementedError('Subclasses should implement this method.')
|
||||
|
||||
|
||||
@@ -24,7 +26,7 @@ class TextContent(Content):
|
||||
text: str
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
def serialize_model(self) -> dict[str, str | dict[str, str]]:
|
||||
data: dict[str, str | dict[str, str]] = {
|
||||
'type': self.type,
|
||||
'text': self.text,
|
||||
@@ -39,7 +41,7 @@ class ImageContent(Content):
|
||||
image_urls: list[str]
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
def serialize_model(self) -> list[dict[str, str | dict[str, str]]]:
|
||||
images: list[dict[str, str | dict[str, str]]] = []
|
||||
for url in self.image_urls:
|
||||
images.append({'type': self.type, 'image_url': {'url': url}})
|
||||
@@ -101,15 +103,22 @@ class Message(BaseModel):
|
||||
# See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
|
||||
if self.role == 'tool' and item.cache_prompt:
|
||||
role_tool_with_prompt_caching = True
|
||||
if isinstance(d, dict):
|
||||
d.pop('cache_control')
|
||||
elif isinstance(d, list):
|
||||
for d_item in d:
|
||||
d_item.pop('cache_control')
|
||||
if isinstance(item, TextContent):
|
||||
d.pop('cache_control', None)
|
||||
elif isinstance(item, ImageContent):
|
||||
# ImageContent.model_dump() always returns a list
|
||||
# We know d is a list of dicts for ImageContent
|
||||
if hasattr(d, '__iter__'):
|
||||
for d_item in d:
|
||||
if hasattr(d_item, 'pop'):
|
||||
d_item.pop('cache_control', None)
|
||||
|
||||
if isinstance(item, TextContent):
|
||||
content.append(d)
|
||||
elif isinstance(item, ImageContent) and self.vision_enabled:
|
||||
content.extend(d)
|
||||
# ImageContent.model_dump() always returns a list
|
||||
# We know d is a list for ImageContent
|
||||
content.extend([d] if isinstance(d, dict) else d)
|
||||
|
||||
message_dict: dict = {'content': content, 'role': self.role}
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ def get_action_message(
|
||||
)
|
||||
|
||||
llm_response: ModelResponse = tool_metadata.model_response
|
||||
assistant_msg = llm_response.choices[0].message
|
||||
assistant_msg = getattr(llm_response.choices[0], 'message')
|
||||
|
||||
# Add the LLM message (assistant) that initiated the tool calls
|
||||
# (overwrites any previous message with the same response_id)
|
||||
@@ -168,7 +168,7 @@ def get_action_message(
|
||||
f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}'
|
||||
)
|
||||
pending_tool_call_action_messages[llm_response.id] = Message(
|
||||
role=assistant_msg.role,
|
||||
role=getattr(assistant_msg, 'role', 'assistant'),
|
||||
# tool call content SHOULD BE a string
|
||||
content=[TextContent(text=assistant_msg.content or '')]
|
||||
if assistant_msg.content is not None
|
||||
@@ -185,7 +185,7 @@ def get_action_message(
|
||||
tool_metadata = action.tool_call_metadata
|
||||
if tool_metadata is not None:
|
||||
# take the response message from the tool call
|
||||
assistant_msg = tool_metadata.model_response.choices[0].message
|
||||
assistant_msg = getattr(tool_metadata.model_response.choices[0], 'message')
|
||||
content = assistant_msg.content or ''
|
||||
|
||||
# save content if any, to thought
|
||||
@@ -197,9 +197,11 @@ def get_action_message(
|
||||
|
||||
# remove the tool call metadata
|
||||
action.tool_call_metadata = None
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role,
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=[TextContent(text=action.thought)],
|
||||
)
|
||||
]
|
||||
@@ -208,9 +210,11 @@ def get_action_message(
|
||||
content = [TextContent(text=action.content or '')]
|
||||
if vision_is_active and action.image_urls:
|
||||
content.append(ImageContent(image_urls=action.image_urls))
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role,
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
@@ -218,7 +222,7 @@ def get_action_message(
|
||||
content = [TextContent(text=f'User executed the command:\n{action.command}')]
|
||||
return [
|
||||
Message(
|
||||
role='user',
|
||||
role='user', # Always user for CmdRunAction
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user