Fix type checking errors in resolver directory (#6738)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-02-18 20:13:33 -05:00 committed by GitHub
parent 1a7003a705
commit f4e5fb2873
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 209 additions and 230 deletions

View File

@ -22,28 +22,28 @@ class GithubIssueHandler(IssueHandlerInterface):
self.clone_url = self.get_clone_url()
self.headers = self.get_headers()
def set_owner(self, owner: str):
def set_owner(self, owner: str) -> None:
self.owner = owner
def get_headers(self):
def get_headers(self) -> dict[str, str]:
return {
'Authorization': f'token {self.token}',
'Accept': 'application/vnd.github.v3+json',
}
def get_base_url(self):
def get_base_url(self) -> str:
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
def get_authorize_url(self):
def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@github.com/'
def get_branch_url(self, branch_name: str):
def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/branches/{branch_name}'
def get_download_url(self):
def get_download_url(self) -> str:
return f'{self.base_url}/issues'
def get_clone_url(self):
def get_clone_url(self) -> str:
username_and_token = (
f'{self.username}:{self.token}'
if self.username
@ -51,10 +51,10 @@ class GithubIssueHandler(IssueHandlerInterface):
)
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
def get_graphql_url(self):
def get_graphql_url(self) -> str:
return 'https://api.github.com/graphql'
def get_compare_url(self, branch_name: str):
def get_compare_url(self, branch_name: str) -> str:
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
def get_converted_issues(
@ -186,7 +186,7 @@ class GithubIssueHandler(IssueHandlerInterface):
print(f'Branch {branch_name} exists: {exists}')
return exists
def get_branch_name(self, base_branch_name: str):
def get_branch_name(self, base_branch_name: str) -> str:
branch_name = base_branch_name
attempt = 1
while self.branch_exists(branch_name):
@ -194,7 +194,7 @@ class GithubIssueHandler(IssueHandlerInterface):
branch_name = f'{base_branch_name}-try{attempt}'
return branch_name
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
# Opting for graphql as REST API doesn't allow reply to replies in comment threads
query = """
mutation($body: String!, $pullRequestReviewThreadId: ID!) {
@ -221,15 +221,18 @@ class GithubIssueHandler(IssueHandlerInterface):
)
response.raise_for_status()
def get_pull_url(self, pr_number: int):
def get_pull_url(self, pr_number: int) -> str:
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
def get_default_branch_name(self) -> str:
response = requests.get(f'{self.base_url}', headers=self.headers)
response.raise_for_status()
return response.json()['default_branch']
data = response.json()
return str(data['default_branch'])
def create_pull_request(self, data=dict) -> dict:
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
response = requests.post(
f'{self.base_url}/pulls', headers=self.headers, json=data
)
@ -240,9 +243,9 @@ class GithubIssueHandler(IssueHandlerInterface):
)
response.raise_for_status()
pr_data = response.json()
return pr_data
return dict(pr_data)
def request_reviewers(self, reviewer: str, pr_number: int):
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
review_data = {'reviewers': [reviewer]}
review_response = requests.post(
f'{self.base_url}/pulls/{pr_number}/requested_reviewers',
@ -254,7 +257,7 @@ class GithubIssueHandler(IssueHandlerInterface):
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
)
def send_comment_msg(self, issue_number: int, msg: str):
def send_comment_msg(self, issue_number: int, msg: str) -> None:
"""Send a comment message to a GitHub issue or pull request.
Args:
@ -282,8 +285,8 @@ class GithubIssueHandler(IssueHandlerInterface):
review_comments: list[str] | None,
review_threads: list[ReviewThread],
thread_comments: list[str] | None,
):
pass
) -> list[str]:
return []
class GithubPRHandler(GithubIssueHandler):
@ -487,7 +490,7 @@ class GithubPRHandler(GithubIssueHandler):
review_comments: list[str] | None,
review_threads: list[ReviewThread],
thread_comments: list[str] | None,
):
) -> list[str]:
new_issue_references = []
if issue_body:

View File

@ -23,38 +23,38 @@ class GitlabIssueHandler(IssueHandlerInterface):
self.clone_url = self.get_clone_url()
self.headers = self.get_headers()
def set_owner(self, owner: str):
def set_owner(self, owner: str) -> None:
self.owner = owner
def get_headers(self):
def get_headers(self) -> dict[str, str]:
return {
'Authorization': f'Bearer {self.token}',
'Accept': 'application/json',
}
def get_base_url(self):
project_path = quote(f'{self.owner}/{self.repo}', safe="")
def get_base_url(self) -> str:
project_path = quote(f'{self.owner}/{self.repo}', safe='')
return f'https://gitlab.com/api/v4/projects/{project_path}'
def get_authorize_url(self):
def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@gitlab.com/'
def get_branch_url(self, branch_name: str):
def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/repository/branches/{branch_name}'
def get_download_url(self):
def get_download_url(self) -> str:
return f'{self.base_url}/issues'
def get_clone_url(self):
def get_clone_url(self) -> str:
username_and_token = self.token
if self.username:
username_and_token = f'{self.username}:{self.token}'
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
def get_graphql_url(self):
def get_graphql_url(self) -> str:
return 'https://gitlab.com/api/graphql'
def get_compare_url(self, branch_name: str):
def get_compare_url(self, branch_name: str) -> str:
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
def get_converted_issues(
@ -189,7 +189,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
print(f'Branch {branch_name} exists: {exists}')
return exists
def get_branch_name(self, base_branch_name: str):
def get_branch_name(self, base_branch_name: str) -> str:
branch_name = base_branch_name
attempt = 1
while self.branch_exists(branch_name):
@ -197,7 +197,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
branch_name = f'{base_branch_name}-try{attempt}'
return branch_name
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
response = requests.get(
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
headers=self.headers,
@ -216,7 +216,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
)
response.raise_for_status()
def get_pull_url(self, pr_number: int):
def get_pull_url(self, pr_number: int) -> str:
return (
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
)
@ -224,9 +224,12 @@ class GitlabIssueHandler(IssueHandlerInterface):
def get_default_branch_name(self) -> str:
response = requests.get(f'{self.base_url}', headers=self.headers)
response.raise_for_status()
return response.json()['default_branch']
data = response.json()
return str(data['default_branch'])
def create_pull_request(self, data=dict) -> dict:
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
response = requests.post(
f'{self.base_url}/merge_requests', headers=self.headers, json=data
)
@ -243,9 +246,9 @@ class GitlabIssueHandler(IssueHandlerInterface):
if 'iid' in pr_data:
pr_data['number'] = pr_data['iid']
return pr_data
return dict(pr_data)
def request_reviewers(self, reviewer: str, pr_number: int):
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
response = requests.get(
f'https://gitlab.com/api/v4/users?username={reviewer}',
headers=self.headers,
@ -264,7 +267,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
)
def send_comment_msg(self, issue_number: int, msg: str):
def send_comment_msg(self, issue_number: int, msg: str) -> None:
"""Send a comment message to a GitHub issue or pull request.
Args:
@ -292,8 +295,8 @@ class GitlabIssueHandler(IssueHandlerInterface):
review_comments: list[str] | None,
review_threads: list[ReviewThread],
thread_comments: list[str] | None,
):
pass
) -> list[str]:
return []
class GitlabPRHandler(GitlabIssueHandler):
@ -479,7 +482,7 @@ class GitlabPRHandler(GitlabIssueHandler):
review_comments: list[str] | None,
review_threads: list[ReviewThread],
thread_comments: list[str] | None,
):
) -> list[str]:
new_issue_references = []
if issue_body:

View File

@ -26,7 +26,7 @@ class Issue(BaseModel):
class IssueHandlerInterface(ABC):
@abstractmethod
def set_owner(self, owner: str):
def set_owner(self, owner: str) -> None:
pass
@abstractmethod
@ -40,43 +40,43 @@ class IssueHandlerInterface(ABC):
pass
@abstractmethod
def get_base_url(self):
def get_base_url(self) -> str:
pass
@abstractmethod
def get_branch_url(self, branch_name):
def get_branch_url(self, branch_name: str) -> str:
pass
@abstractmethod
def get_download_url(self):
def get_download_url(self) -> str:
pass
@abstractmethod
def get_clone_url(self):
def get_clone_url(self) -> str:
pass
@abstractmethod
def get_pull_url(self, pr_number: int):
def get_pull_url(self, pr_number: int) -> str:
pass
@abstractmethod
def get_graphql_url(self):
def get_graphql_url(self) -> str:
pass
@abstractmethod
def get_headers(self):
def get_headers(self) -> dict[str, str]:
pass
@abstractmethod
def get_compare_url(self, branch_name):
def get_compare_url(self, branch_name: str) -> str:
pass
@abstractmethod
def get_branch_name(self, base_branch_name: str):
def get_branch_name(self, base_branch_name: str) -> str:
pass
@abstractmethod
def get_default_branch_name(self):
def get_default_branch_name(self) -> str:
pass
@abstractmethod
@ -84,23 +84,25 @@ class IssueHandlerInterface(ABC):
pass
@abstractmethod
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
pass
@abstractmethod
def send_comment_msg(self, issue_number: int, msg: str):
def send_comment_msg(self, issue_number: int, msg: str) -> None:
pass
@abstractmethod
def get_authorize_url(self):
def get_authorize_url(self) -> str:
pass
@abstractmethod
def create_pull_request(self, data=dict) -> dict:
pass
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
raise NotImplementedError
@abstractmethod
def request_reviewers(self, reviewer: str, pr_number: int):
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
pass
@abstractmethod
@ -112,7 +114,7 @@ class IssueHandlerInterface(ABC):
review_comments: list[str] | None,
review_threads: list[ReviewThread],
thread_comments: list[str] | None,
):
) -> list[str]:
pass
@abstractmethod

View File

@ -25,7 +25,7 @@ class ServiceContext:
if llm_config is not None:
self.llm = LLM(llm_config)
def set_strategy(self, strategy):
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
self._strategy = strategy
@ -36,7 +36,7 @@ class ServiceContextPR(ServiceContext):
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig):
super().__init__(strategy, llm_config)
def get_clone_url(self):
def get_clone_url(self) -> str:
return self._strategy.get_clone_url()
def download_issues(self) -> list[Any]:
@ -266,31 +266,31 @@ class ServiceContextIssue(ServiceContext):
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
super().__init__(strategy, llm_config)
def get_base_url(self):
def get_base_url(self) -> str:
return self._strategy.get_base_url()
def get_branch_url(self, branch_name):
def get_branch_url(self, branch_name: str) -> str:
return self._strategy.get_branch_url(branch_name)
def get_download_url(self):
def get_download_url(self) -> str:
return self._strategy.get_download_url()
def get_clone_url(self):
def get_clone_url(self) -> str:
return self._strategy.get_clone_url()
def get_graphql_url(self):
def get_graphql_url(self) -> str:
return self._strategy.get_graphql_url()
def get_headers(self):
def get_headers(self) -> dict[str, str]:
return self._strategy.get_headers()
def get_authorize_url(self):
def get_authorize_url(self) -> str:
return self._strategy.get_authorize_url()
def get_pull_url(self, pr_number: int):
def get_pull_url(self, pr_number: int) -> str:
return self._strategy.get_pull_url(pr_number)
def get_compare_url(self, branch_name: str):
def get_compare_url(self, branch_name: str) -> str:
return self._strategy.get_compare_url(branch_name)
def download_issues(self) -> list[Any]:
@ -299,25 +299,27 @@ class ServiceContextIssue(ServiceContext):
def get_branch_name(
self,
base_branch_name: str,
):
) -> str:
return self._strategy.get_branch_name(base_branch_name)
def branch_exists(self, branch_name: str):
def branch_exists(self, branch_name: str) -> bool:
return self._strategy.branch_exists(branch_name)
def get_default_branch_name(self) -> str:
return self._strategy.get_default_branch_name()
def create_pull_request(self, data=dict):
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
return self._strategy.create_pull_request(data)
def request_reviewers(self, reviewer: str, pr_number: int):
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
return self._strategy.request_reviewers(reviewer, pr_number)
def reply_to_comment(self, pr_number, comment_id, reply):
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
return self._strategy.reply_to_comment(pr_number, comment_id, reply)
def send_comment_msg(self, issue_number: int, msg: str):
def send_comment_msg(self, issue_number: int, msg: str) -> None:
return self._strategy.send_comment_msg(issue_number, msg)
def get_issue_comments(

View File

@ -5,10 +5,13 @@ import subprocess
import tempfile
from .exceptions import HunkApplyException, SubprocessException
from .patch import Change, diffobj
from .snippets import remove, which
def _apply_diff_with_subprocess(diff, lines, reverse=False):
def _apply_diff_with_subprocess(
diff: diffobj, lines: list[str], reverse: bool = False
) -> tuple[list[str], list[str] | None]:
# call out to patch program
patchexec = which('patch')
if not patchexec:
@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False):
return lines, rejlines
def _reverse(changes):
def _reverse_change(c):
def _reverse(changes: list[Change]) -> list[Change]:
def _reverse_change(c: Change) -> Change:
return c._replace(old=c.new, new=c.old)
return [_reverse_change(c) for c in changes]
def apply_diff(diff, text, reverse=False, use_patch=False):
try:
lines = text.splitlines()
except AttributeError:
lines = list(text)
def apply_diff(
diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False
) -> list[str]:
lines = text.splitlines() if isinstance(text, str) else list(text)
if use_patch:
return _apply_diff_with_subprocess(diff, lines, reverse)
lines, _ = _apply_diff_with_subprocess(diff, lines, reverse)
return lines
n_lines = len(lines)

View File

@ -1,31 +1,31 @@
class PatchingException(Exception):
pass
class HunkException(PatchingException):
def __init__(self, msg, hunk=None):
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
else:
super(HunkException, self).__init__(msg)
class ApplyException(PatchingException):
pass
class SubprocessException(ApplyException):
def __init__(self, msg, code):
super(SubprocessException, self).__init__(msg)
self.code = code
class HunkApplyException(HunkException, ApplyException, ValueError):
pass
class ParseException(HunkException, ValueError):
pass
class PatchingException(Exception):
pass
class HunkException(PatchingException):
def __init__(self, msg: str, hunk: int | None = None) -> None:
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
else:
super(HunkException, self).__init__(msg)
class ApplyException(PatchingException):
pass
class SubprocessException(ApplyException):
def __init__(self, msg: str, code: int) -> None:
super(SubprocessException, self).__init__(msg)
self.code = code
class HunkApplyException(HunkException, ApplyException, ValueError):
pass
class ParseException(HunkException, ValueError):
pass

View File

@ -3,6 +3,7 @@ import base64
import re
import zlib
from collections import namedtuple
from typing import Iterable
from . import exceptions
from .snippets import findall_regex, split_by_regex
@ -71,11 +72,8 @@ cvs_header_timestamp_colon = re.compile(r':([\d.]+)\t(.+)')
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
def parse_patch(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_patch(text: str | list[str]) -> Iterable[diffobj]:
lines = text.splitlines() if isinstance(text, str) else text
# maybe use this to nuke all of those line endings?
# lines = [x.splitlines()[0] for x in lines]
@ -104,18 +102,15 @@ def parse_patch(text):
yield diffobj(header=h, changes=d, text=difftext)
def parse_header(text):
def parse_header(text: str | list[str]) -> header | None:
h = parse_scm_header(text)
if h is None:
h = parse_diff_header(text)
return h
def parse_scm_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_scm_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
check = [
(git_header_index, parse_git_header),
@ -154,11 +149,8 @@ def parse_scm_header(text):
return None
def parse_diff_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_diff_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
check = [
(unified_header_new_line, parse_unified_header),
@ -178,10 +170,10 @@ def parse_diff_header(text):
return None # no header?
def parse_diff(text):
try:
def parse_diff(text: str | list[str]) -> list[Change] | None:
if isinstance(text, str):
lines = text.splitlines()
except AttributeError:
else:
lines = text
check = [
@ -200,11 +192,8 @@ def parse_diff(text):
return None
def parse_git_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_git_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
old_version = None
new_version = None
@ -275,11 +264,8 @@ def parse_git_header(text):
return None
def parse_svn_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_svn_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, svn_header_index)
if len(headers) == 0:
@ -346,11 +332,8 @@ def parse_svn_header(text):
return None
def parse_cvs_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_cvs_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, cvs_header_rcs)
headers_old = findall_regex(lines, old_cvs_diffcmd_header)
@ -430,11 +413,8 @@ def parse_cvs_header(text):
return None
def parse_diffcmd_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_diffcmd_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, diffcmd_header)
if len(headers) == 0:
@ -454,11 +434,8 @@ def parse_diffcmd_header(text):
return None
def parse_unified_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_unified_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, unified_header_new_line)
if len(headers) == 0:
@ -490,11 +467,8 @@ def parse_unified_header(text):
return None
def parse_context_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_context_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, context_header_old_line)
if len(headers) == 0:
@ -526,11 +500,8 @@ def parse_context_header(text):
return None
def parse_default_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_default_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@ -582,11 +553,8 @@ def parse_default_diff(text):
return None
def parse_unified_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_unified_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@ -652,11 +620,8 @@ def parse_unified_diff(text):
return None
def parse_context_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_context_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@ -795,11 +760,8 @@ def parse_context_diff(text):
return None
def parse_ed_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_ed_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
j = 0
@ -878,12 +840,9 @@ def parse_ed_diff(text):
return None
def parse_rcs_ed_diff(text):
def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None:
# much like forward ed, but no 'c' type
try:
lines = text.splitlines()
except AttributeError:
lines = text
lines = text.splitlines() if isinstance(text, str) else text
old = 0
j = 0
@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text):
hunk_kind = o.group(1)
old = int(o.group(2))
size = int(o.group(3))
size = int(o.group(3)) if o.group(3) else 0
if hunk_kind == 'a':
old += total_change_size + 1
@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text):
if len(changes) > 0:
return changes
return None
def parse_git_binary_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
changes: list[Change] = list()

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
import os
import re
from shutil import rmtree
def remove(path):
def remove(path: str) -> None:
if os.path.exists(path):
if os.path.isdir(path):
rmtree(path)
@ -13,7 +14,7 @@ def remove(path):
# find all indices of a list of strings that match a regex
def findall_regex(items, regex):
def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]:
found = list()
for i in range(0, len(items)):
k = regex.match(items[i])
@ -24,7 +25,7 @@ def findall_regex(items, regex):
return found
def split_by_regex(items, regex):
def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]:
splits = list()
indices = findall_regex(items, regex)
if not indices:
@ -45,8 +46,8 @@ def split_by_regex(items, regex):
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
def which(program):
def is_exe(fpath):
def which(program: str) -> str | None:
def is_exe(fpath: str) -> bool:
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, fname = os.path.split(program)

View File

@ -6,8 +6,9 @@ import multiprocessing as mp
import os
import pathlib
import subprocess
from typing import Awaitable, TextIO
from typing import Any, Awaitable, TextIO
from pydantic import SecretStr
from tqdm import tqdm
import openhands
@ -25,7 +26,7 @@ from openhands.resolver.utils import (
)
def cleanup():
def cleanup() -> None:
print('Cleaning up child processes...')
for process in mp.active_children():
print(f'Terminating child process: {process.name}')
@ -214,7 +215,7 @@ async def resolve_issues(
# Use asyncio.gather with a semaphore to limit concurrency
sem = asyncio.Semaphore(num_workers)
async def run_with_semaphore(task):
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
async with sem:
return await task
@ -228,7 +229,7 @@ async def resolve_issues(
logger.info('Finished.')
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Resolve multiple issues from Github or Gitlab.'
)
@ -349,7 +350,7 @@ def main():
llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

View File

@ -10,6 +10,7 @@ import subprocess
from typing import Any
from uuid import uuid4
from pydantic import SecretStr
from termcolor import colored
import openhands
@ -18,6 +19,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig, SandboxConf
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.event import Event
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
@ -48,7 +50,7 @@ AGENT_CLASS = 'CodeActAgent'
def initialize_runtime(
runtime: Runtime,
platform: Platform,
):
) -> None:
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
@ -192,26 +194,28 @@ async def process_issue(
# This code looks unnecessary because these are default values in the config class
# they're set by default if nothing else overrides them
# FIXME we should remove them here
kwargs = {}
sandbox_config = SandboxConfig(
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
)
if os.getenv('GITLAB_CI') == 'True':
kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost')
sandbox_config.local_runtime_url = os.getenv(
'LOCAL_RUNTIME_URL', 'http://localhost'
)
user_id = os.getuid() if hasattr(os, 'getuid') else 1000
if user_id == 0:
kwargs['user_id'] = get_unique_uid()
sandbox_config.user_id = get_unique_uid()
config = AppConfig(
default_agent='CodeActAgent',
runtime='docker',
max_budget_per_task=4,
max_iterations=max_iterations,
sandbox=SandboxConfig(
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
**kwargs,
),
sandbox=sandbox_config,
# do not mount workspace
workspace_base=workspace_base,
workspace_mount_path=workspace_base,
@ -222,7 +226,7 @@ async def process_issue(
runtime = create_runtime(config)
await runtime.connect()
def on_event(evt):
def on_event(evt: Event) -> None:
logger.info(evt)
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
@ -524,10 +528,10 @@ async def resolve_issue(
logger.info('Finished.')
def main():
def main() -> None:
import argparse
def int_or_none(value):
def int_or_none(value: str) -> int | None:
if value.lower() == 'none':
return None
else:
@ -654,7 +658,7 @@ def main():
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

View File

@ -5,6 +5,7 @@ import shutil
import subprocess
import jinja2
from pydantic import SecretStr
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
@ -543,7 +544,7 @@ def process_all_successful_issues(
)
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Send a pull request to Github or Gitlab.'
)
@ -641,7 +642,7 @@ def main():
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

View File

@ -107,7 +107,7 @@ def codeact_user_response(
return msg
def cleanup():
def cleanup() -> None:
print('Cleaning up child processes...')
for process in mp.active_children():
print(f'Terminating child process: {process.name}')
@ -115,7 +115,9 @@ def cleanup():
process.join()
def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
def prepare_dataset(
dataset: pd.DataFrame, output_file: str, eval_n_limit: int
) -> pd.DataFrame:
assert 'instance_id' in dataset.columns, (
"Expected 'instance_id' column in the dataset. You should define your own "
"unique identifier for each instance and use it as the 'instance_id' column."
@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
def reset_logger_for_multiprocessing(
logger: logging.Logger, instance_id: str, log_dir: str
):
) -> None:
"""Reset the logger for multiprocessing.
Save logs to a separate file for each process, instead of trying to write to the
@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]:
return [int(match) for match in re.findall(pattern, body)]
def get_unique_uid(start_uid=1000):
def get_unique_uid(start_uid: int = 1000) -> int:
existing_uids = set()
with open('/etc/passwd', 'r') as passwd_file:
for line in passwd_file:

View File

@ -4,7 +4,9 @@ import os
from openhands.resolver.io_utils import load_single_resolver_output
def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str):
def visualize_resolver_output(
issue_number: int, output_dir: str, vis_method: str
) -> None:
output_jsonl = os.path.join(output_dir, 'output.jsonl')
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
if vis_method == 'json':