mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 13:52:43 +08:00
Fix type checking errors in resolver directory (#6738)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
1a7003a705
commit
f4e5fb2873
@ -22,28 +22,28 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
self.clone_url = self.get_clone_url()
|
||||
self.headers = self.get_headers()
|
||||
|
||||
def set_owner(self, owner: str):
|
||||
def set_owner(self, owner: str) -> None:
|
||||
self.owner = owner
|
||||
|
||||
def get_headers(self):
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
'Authorization': f'token {self.token}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
def get_base_url(self) -> str:
|
||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||
|
||||
def get_authorize_url(self):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f'https://{self.username}:{self.token}@github.com/'
|
||||
|
||||
def get_branch_url(self, branch_name: str):
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
return self.get_base_url() + f'/branches/{branch_name}'
|
||||
|
||||
def get_download_url(self):
|
||||
def get_download_url(self) -> str:
|
||||
return f'{self.base_url}/issues'
|
||||
|
||||
def get_clone_url(self):
|
||||
def get_clone_url(self) -> str:
|
||||
username_and_token = (
|
||||
f'{self.username}:{self.token}'
|
||||
if self.username
|
||||
@ -51,10 +51,10 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
)
|
||||
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self):
|
||||
def get_graphql_url(self) -> str:
|
||||
return 'https://api.github.com/graphql'
|
||||
|
||||
def get_compare_url(self, branch_name: str):
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
|
||||
|
||||
def get_converted_issues(
|
||||
@ -186,7 +186,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
print(f'Branch {branch_name} exists: {exists}')
|
||||
return exists
|
||||
|
||||
def get_branch_name(self, base_branch_name: str):
|
||||
def get_branch_name(self, base_branch_name: str) -> str:
|
||||
branch_name = base_branch_name
|
||||
attempt = 1
|
||||
while self.branch_exists(branch_name):
|
||||
@ -194,7 +194,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
branch_name = f'{base_branch_name}-try{attempt}'
|
||||
return branch_name
|
||||
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
# Opting for graphql as REST API doesn't allow reply to replies in comment threads
|
||||
query = """
|
||||
mutation($body: String!, $pullRequestReviewThreadId: ID!) {
|
||||
@ -221,15 +221,18 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def get_pull_url(self, pr_number: int):
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
|
||||
|
||||
def get_default_branch_name(self) -> str:
|
||||
response = requests.get(f'{self.base_url}', headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()['default_branch']
|
||||
data = response.json()
|
||||
return str(data['default_branch'])
|
||||
|
||||
def create_pull_request(self, data=dict) -> dict:
|
||||
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
if data is None:
|
||||
data = {}
|
||||
response = requests.post(
|
||||
f'{self.base_url}/pulls', headers=self.headers, json=data
|
||||
)
|
||||
@ -240,9 +243,9 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
)
|
||||
response.raise_for_status()
|
||||
pr_data = response.json()
|
||||
return pr_data
|
||||
return dict(pr_data)
|
||||
|
||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
||||
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||
review_data = {'reviewers': [reviewer]}
|
||||
review_response = requests.post(
|
||||
f'{self.base_url}/pulls/{pr_number}/requested_reviewers',
|
||||
@ -254,7 +257,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
||||
)
|
||||
|
||||
def send_comment_msg(self, issue_number: int, msg: str):
|
||||
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||
"""Send a comment message to a GitHub issue or pull request.
|
||||
|
||||
Args:
|
||||
@ -282,8 +285,8 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
review_comments: list[str] | None,
|
||||
review_threads: list[ReviewThread],
|
||||
thread_comments: list[str] | None,
|
||||
):
|
||||
pass
|
||||
) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
class GithubPRHandler(GithubIssueHandler):
|
||||
@ -487,7 +490,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
review_comments: list[str] | None,
|
||||
review_threads: list[ReviewThread],
|
||||
thread_comments: list[str] | None,
|
||||
):
|
||||
) -> list[str]:
|
||||
new_issue_references = []
|
||||
|
||||
if issue_body:
|
||||
|
||||
@ -23,38 +23,38 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
self.clone_url = self.get_clone_url()
|
||||
self.headers = self.get_headers()
|
||||
|
||||
def set_owner(self, owner: str):
|
||||
def set_owner(self, owner: str) -> None:
|
||||
self.owner = owner
|
||||
|
||||
def get_headers(self):
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
'Authorization': f'Bearer {self.token}',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
project_path = quote(f'{self.owner}/{self.repo}', safe="")
|
||||
def get_base_url(self) -> str:
|
||||
project_path = quote(f'{self.owner}/{self.repo}', safe='')
|
||||
return f'https://gitlab.com/api/v4/projects/{project_path}'
|
||||
|
||||
def get_authorize_url(self):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f'https://{self.username}:{self.token}@gitlab.com/'
|
||||
|
||||
def get_branch_url(self, branch_name: str):
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
return self.get_base_url() + f'/repository/branches/{branch_name}'
|
||||
|
||||
def get_download_url(self):
|
||||
def get_download_url(self) -> str:
|
||||
return f'{self.base_url}/issues'
|
||||
|
||||
def get_clone_url(self):
|
||||
def get_clone_url(self) -> str:
|
||||
username_and_token = self.token
|
||||
if self.username:
|
||||
username_and_token = f'{self.username}:{self.token}'
|
||||
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self):
|
||||
def get_graphql_url(self) -> str:
|
||||
return 'https://gitlab.com/api/graphql'
|
||||
|
||||
def get_compare_url(self, branch_name: str):
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
|
||||
|
||||
def get_converted_issues(
|
||||
@ -189,7 +189,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
print(f'Branch {branch_name} exists: {exists}')
|
||||
return exists
|
||||
|
||||
def get_branch_name(self, base_branch_name: str):
|
||||
def get_branch_name(self, base_branch_name: str) -> str:
|
||||
branch_name = base_branch_name
|
||||
attempt = 1
|
||||
while self.branch_exists(branch_name):
|
||||
@ -197,7 +197,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
branch_name = f'{base_branch_name}-try{attempt}'
|
||||
return branch_name
|
||||
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
response = requests.get(
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
|
||||
headers=self.headers,
|
||||
@ -216,7 +216,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def get_pull_url(self, pr_number: int):
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
return (
|
||||
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
|
||||
)
|
||||
@ -224,9 +224,12 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
def get_default_branch_name(self) -> str:
|
||||
response = requests.get(f'{self.base_url}', headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()['default_branch']
|
||||
data = response.json()
|
||||
return str(data['default_branch'])
|
||||
|
||||
def create_pull_request(self, data=dict) -> dict:
|
||||
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
if data is None:
|
||||
data = {}
|
||||
response = requests.post(
|
||||
f'{self.base_url}/merge_requests', headers=self.headers, json=data
|
||||
)
|
||||
@ -243,9 +246,9 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
if 'iid' in pr_data:
|
||||
pr_data['number'] = pr_data['iid']
|
||||
|
||||
return pr_data
|
||||
return dict(pr_data)
|
||||
|
||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
||||
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||
response = requests.get(
|
||||
f'https://gitlab.com/api/v4/users?username={reviewer}',
|
||||
headers=self.headers,
|
||||
@ -264,7 +267,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
||||
)
|
||||
|
||||
def send_comment_msg(self, issue_number: int, msg: str):
|
||||
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||
"""Send a comment message to a GitHub issue or pull request.
|
||||
|
||||
Args:
|
||||
@ -292,8 +295,8 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
review_comments: list[str] | None,
|
||||
review_threads: list[ReviewThread],
|
||||
thread_comments: list[str] | None,
|
||||
):
|
||||
pass
|
||||
) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
class GitlabPRHandler(GitlabIssueHandler):
|
||||
@ -479,7 +482,7 @@ class GitlabPRHandler(GitlabIssueHandler):
|
||||
review_comments: list[str] | None,
|
||||
review_threads: list[ReviewThread],
|
||||
thread_comments: list[str] | None,
|
||||
):
|
||||
) -> list[str]:
|
||||
new_issue_references = []
|
||||
|
||||
if issue_body:
|
||||
|
||||
@ -26,7 +26,7 @@ class Issue(BaseModel):
|
||||
|
||||
class IssueHandlerInterface(ABC):
|
||||
@abstractmethod
|
||||
def set_owner(self, owner: str):
|
||||
def set_owner(self, owner: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -40,43 +40,43 @@ class IssueHandlerInterface(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_base_url(self):
|
||||
def get_base_url(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_branch_url(self, branch_name):
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_download_url(self):
|
||||
def get_download_url(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_clone_url(self):
|
||||
def get_clone_url(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pull_url(self, pr_number: int):
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_graphql_url(self):
|
||||
def get_graphql_url(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_headers(self):
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_compare_url(self, branch_name):
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_branch_name(self, base_branch_name: str):
|
||||
def get_branch_name(self, base_branch_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default_branch_name(self):
|
||||
def get_default_branch_name(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -84,23 +84,25 @@ class IssueHandlerInterface(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send_comment_msg(self, issue_number: int, msg: str):
|
||||
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_authorize_url(self):
|
||||
def get_authorize_url(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_pull_request(self, data=dict) -> dict:
|
||||
pass
|
||||
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
if data is None:
|
||||
data = {}
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
||||
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -112,7 +114,7 @@ class IssueHandlerInterface(ABC):
|
||||
review_comments: list[str] | None,
|
||||
review_threads: list[ReviewThread],
|
||||
thread_comments: list[str] | None,
|
||||
):
|
||||
) -> list[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -25,7 +25,7 @@ class ServiceContext:
|
||||
if llm_config is not None:
|
||||
self.llm = LLM(llm_config)
|
||||
|
||||
def set_strategy(self, strategy):
|
||||
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ class ServiceContextPR(ServiceContext):
|
||||
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig):
|
||||
super().__init__(strategy, llm_config)
|
||||
|
||||
def get_clone_url(self):
|
||||
def get_clone_url(self) -> str:
|
||||
return self._strategy.get_clone_url()
|
||||
|
||||
def download_issues(self) -> list[Any]:
|
||||
@ -266,31 +266,31 @@ class ServiceContextIssue(ServiceContext):
|
||||
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
|
||||
super().__init__(strategy, llm_config)
|
||||
|
||||
def get_base_url(self):
|
||||
def get_base_url(self) -> str:
|
||||
return self._strategy.get_base_url()
|
||||
|
||||
def get_branch_url(self, branch_name):
|
||||
def get_branch_url(self, branch_name: str) -> str:
|
||||
return self._strategy.get_branch_url(branch_name)
|
||||
|
||||
def get_download_url(self):
|
||||
def get_download_url(self) -> str:
|
||||
return self._strategy.get_download_url()
|
||||
|
||||
def get_clone_url(self):
|
||||
def get_clone_url(self) -> str:
|
||||
return self._strategy.get_clone_url()
|
||||
|
||||
def get_graphql_url(self):
|
||||
def get_graphql_url(self) -> str:
|
||||
return self._strategy.get_graphql_url()
|
||||
|
||||
def get_headers(self):
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
return self._strategy.get_headers()
|
||||
|
||||
def get_authorize_url(self):
|
||||
def get_authorize_url(self) -> str:
|
||||
return self._strategy.get_authorize_url()
|
||||
|
||||
def get_pull_url(self, pr_number: int):
|
||||
def get_pull_url(self, pr_number: int) -> str:
|
||||
return self._strategy.get_pull_url(pr_number)
|
||||
|
||||
def get_compare_url(self, branch_name: str):
|
||||
def get_compare_url(self, branch_name: str) -> str:
|
||||
return self._strategy.get_compare_url(branch_name)
|
||||
|
||||
def download_issues(self) -> list[Any]:
|
||||
@ -299,25 +299,27 @@ class ServiceContextIssue(ServiceContext):
|
||||
def get_branch_name(
|
||||
self,
|
||||
base_branch_name: str,
|
||||
):
|
||||
) -> str:
|
||||
return self._strategy.get_branch_name(base_branch_name)
|
||||
|
||||
def branch_exists(self, branch_name: str):
|
||||
def branch_exists(self, branch_name: str) -> bool:
|
||||
return self._strategy.branch_exists(branch_name)
|
||||
|
||||
def get_default_branch_name(self) -> str:
|
||||
return self._strategy.get_default_branch_name()
|
||||
|
||||
def create_pull_request(self, data=dict):
|
||||
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
if data is None:
|
||||
data = {}
|
||||
return self._strategy.create_pull_request(data)
|
||||
|
||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
||||
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||
return self._strategy.request_reviewers(reviewer, pr_number)
|
||||
|
||||
def reply_to_comment(self, pr_number, comment_id, reply):
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
return self._strategy.reply_to_comment(pr_number, comment_id, reply)
|
||||
|
||||
def send_comment_msg(self, issue_number: int, msg: str):
|
||||
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||
return self._strategy.send_comment_msg(issue_number, msg)
|
||||
|
||||
def get_issue_comments(
|
||||
|
||||
@ -5,10 +5,13 @@ import subprocess
|
||||
import tempfile
|
||||
|
||||
from .exceptions import HunkApplyException, SubprocessException
|
||||
from .patch import Change, diffobj
|
||||
from .snippets import remove, which
|
||||
|
||||
|
||||
def _apply_diff_with_subprocess(diff, lines, reverse=False):
|
||||
def _apply_diff_with_subprocess(
|
||||
diff: diffobj, lines: list[str], reverse: bool = False
|
||||
) -> tuple[list[str], list[str] | None]:
|
||||
# call out to patch program
|
||||
patchexec = which('patch')
|
||||
if not patchexec:
|
||||
@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False):
|
||||
return lines, rejlines
|
||||
|
||||
|
||||
def _reverse(changes):
|
||||
def _reverse_change(c):
|
||||
def _reverse(changes: list[Change]) -> list[Change]:
|
||||
def _reverse_change(c: Change) -> Change:
|
||||
return c._replace(old=c.new, new=c.old)
|
||||
|
||||
return [_reverse_change(c) for c in changes]
|
||||
|
||||
|
||||
def apply_diff(diff, text, reverse=False, use_patch=False):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = list(text)
|
||||
def apply_diff(
|
||||
diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False
|
||||
) -> list[str]:
|
||||
lines = text.splitlines() if isinstance(text, str) else list(text)
|
||||
|
||||
if use_patch:
|
||||
return _apply_diff_with_subprocess(diff, lines, reverse)
|
||||
lines, _ = _apply_diff_with_subprocess(diff, lines, reverse)
|
||||
return lines
|
||||
|
||||
n_lines = len(lines)
|
||||
|
||||
|
||||
@ -1,31 +1,31 @@
|
||||
class PatchingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HunkException(PatchingException):
|
||||
def __init__(self, msg, hunk=None):
|
||||
self.hunk = hunk
|
||||
if hunk is not None:
|
||||
super(HunkException, self).__init__(
|
||||
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
||||
)
|
||||
else:
|
||||
super(HunkException, self).__init__(msg)
|
||||
|
||||
|
||||
class ApplyException(PatchingException):
|
||||
pass
|
||||
|
||||
|
||||
class SubprocessException(ApplyException):
|
||||
def __init__(self, msg, code):
|
||||
super(SubprocessException, self).__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
class HunkApplyException(HunkException, ApplyException, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ParseException(HunkException, ValueError):
|
||||
pass
|
||||
class PatchingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HunkException(PatchingException):
|
||||
def __init__(self, msg: str, hunk: int | None = None) -> None:
|
||||
self.hunk = hunk
|
||||
if hunk is not None:
|
||||
super(HunkException, self).__init__(
|
||||
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
||||
)
|
||||
else:
|
||||
super(HunkException, self).__init__(msg)
|
||||
|
||||
|
||||
class ApplyException(PatchingException):
|
||||
pass
|
||||
|
||||
|
||||
class SubprocessException(ApplyException):
|
||||
def __init__(self, msg: str, code: int) -> None:
|
||||
super(SubprocessException, self).__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
class HunkApplyException(HunkException, ApplyException, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ParseException(HunkException, ValueError):
|
||||
pass
|
||||
|
||||
@ -3,6 +3,7 @@ import base64
|
||||
import re
|
||||
import zlib
|
||||
from collections import namedtuple
|
||||
from typing import Iterable
|
||||
|
||||
from . import exceptions
|
||||
from .snippets import findall_regex, split_by_regex
|
||||
@ -71,11 +72,8 @@ cvs_header_timestamp_colon = re.compile(r':([\d.]+)\t(.+)')
|
||||
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
|
||||
|
||||
|
||||
def parse_patch(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_patch(text: str | list[str]) -> Iterable[diffobj]:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
# maybe use this to nuke all of those line endings?
|
||||
# lines = [x.splitlines()[0] for x in lines]
|
||||
@ -104,18 +102,15 @@ def parse_patch(text):
|
||||
yield diffobj(header=h, changes=d, text=difftext)
|
||||
|
||||
|
||||
def parse_header(text):
|
||||
def parse_header(text: str | list[str]) -> header | None:
|
||||
h = parse_scm_header(text)
|
||||
if h is None:
|
||||
h = parse_diff_header(text)
|
||||
return h
|
||||
|
||||
|
||||
def parse_scm_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_scm_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
check = [
|
||||
(git_header_index, parse_git_header),
|
||||
@ -154,11 +149,8 @@ def parse_scm_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_diff_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_diff_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
check = [
|
||||
(unified_header_new_line, parse_unified_header),
|
||||
@ -178,10 +170,10 @@ def parse_diff_header(text):
|
||||
return None # no header?
|
||||
|
||||
|
||||
def parse_diff(text):
|
||||
try:
|
||||
def parse_diff(text: str | list[str]) -> list[Change] | None:
|
||||
if isinstance(text, str):
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
else:
|
||||
lines = text
|
||||
|
||||
check = [
|
||||
@ -200,11 +192,8 @@ def parse_diff(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_git_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_git_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old_version = None
|
||||
new_version = None
|
||||
@ -275,11 +264,8 @@ def parse_git_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_svn_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_svn_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
headers = findall_regex(lines, svn_header_index)
|
||||
if len(headers) == 0:
|
||||
@ -346,11 +332,8 @@ def parse_svn_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_cvs_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_cvs_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
headers = findall_regex(lines, cvs_header_rcs)
|
||||
headers_old = findall_regex(lines, old_cvs_diffcmd_header)
|
||||
@ -430,11 +413,8 @@ def parse_cvs_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_diffcmd_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_diffcmd_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
headers = findall_regex(lines, diffcmd_header)
|
||||
if len(headers) == 0:
|
||||
@ -454,11 +434,8 @@ def parse_diffcmd_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_unified_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_unified_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
headers = findall_regex(lines, unified_header_new_line)
|
||||
if len(headers) == 0:
|
||||
@ -490,11 +467,8 @@ def parse_unified_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_context_header(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_context_header(text: str | list[str]) -> header | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
headers = findall_regex(lines, context_header_old_line)
|
||||
if len(headers) == 0:
|
||||
@ -526,11 +500,8 @@ def parse_context_header(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_default_diff(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_default_diff(text: str | list[str]) -> list[Change] | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old = 0
|
||||
new = 0
|
||||
@ -582,11 +553,8 @@ def parse_default_diff(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_unified_diff(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_unified_diff(text: str | list[str]) -> list[Change] | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old = 0
|
||||
new = 0
|
||||
@ -652,11 +620,8 @@ def parse_unified_diff(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_context_diff(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_context_diff(text: str | list[str]) -> list[Change] | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old = 0
|
||||
new = 0
|
||||
@ -795,11 +760,8 @@ def parse_context_diff(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_ed_diff(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_ed_diff(text: str | list[str]) -> list[Change] | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old = 0
|
||||
j = 0
|
||||
@ -878,12 +840,9 @@ def parse_ed_diff(text):
|
||||
return None
|
||||
|
||||
|
||||
def parse_rcs_ed_diff(text):
|
||||
def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None:
|
||||
# much like forward ed, but no 'c' type
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
old = 0
|
||||
j = 0
|
||||
@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text):
|
||||
|
||||
hunk_kind = o.group(1)
|
||||
old = int(o.group(2))
|
||||
size = int(o.group(3))
|
||||
size = int(o.group(3)) if o.group(3) else 0
|
||||
|
||||
if hunk_kind == 'a':
|
||||
old += total_change_size + 1
|
||||
@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text):
|
||||
|
||||
if len(changes) > 0:
|
||||
return changes
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_git_binary_diff(text):
|
||||
try:
|
||||
lines = text.splitlines()
|
||||
except AttributeError:
|
||||
lines = text
|
||||
def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None:
|
||||
lines = text.splitlines() if isinstance(text, str) else text
|
||||
|
||||
changes: list[Change] = list()
|
||||
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import re
|
||||
from shutil import rmtree
|
||||
|
||||
|
||||
def remove(path):
|
||||
def remove(path: str) -> None:
|
||||
if os.path.exists(path):
|
||||
if os.path.isdir(path):
|
||||
rmtree(path)
|
||||
@ -13,7 +14,7 @@ def remove(path):
|
||||
|
||||
|
||||
# find all indices of a list of strings that match a regex
|
||||
def findall_regex(items, regex):
|
||||
def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]:
|
||||
found = list()
|
||||
for i in range(0, len(items)):
|
||||
k = regex.match(items[i])
|
||||
@ -24,7 +25,7 @@ def findall_regex(items, regex):
|
||||
return found
|
||||
|
||||
|
||||
def split_by_regex(items, regex):
|
||||
def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]:
|
||||
splits = list()
|
||||
indices = findall_regex(items, regex)
|
||||
if not indices:
|
||||
@ -45,8 +46,8 @@ def split_by_regex(items, regex):
|
||||
|
||||
|
||||
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
|
||||
def which(program):
|
||||
def is_exe(fpath):
|
||||
def which(program: str) -> str | None:
|
||||
def is_exe(fpath: str) -> bool:
|
||||
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
||||
|
||||
fpath, fname = os.path.split(program)
|
||||
|
||||
@ -6,8 +6,9 @@ import multiprocessing as mp
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from typing import Awaitable, TextIO
|
||||
from typing import Any, Awaitable, TextIO
|
||||
|
||||
from pydantic import SecretStr
|
||||
from tqdm import tqdm
|
||||
|
||||
import openhands
|
||||
@ -25,7 +26,7 @@ from openhands.resolver.utils import (
|
||||
)
|
||||
|
||||
|
||||
def cleanup():
|
||||
def cleanup() -> None:
|
||||
print('Cleaning up child processes...')
|
||||
for process in mp.active_children():
|
||||
print(f'Terminating child process: {process.name}')
|
||||
@ -214,7 +215,7 @@ async def resolve_issues(
|
||||
# Use asyncio.gather with a semaphore to limit concurrency
|
||||
sem = asyncio.Semaphore(num_workers)
|
||||
|
||||
async def run_with_semaphore(task):
|
||||
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
|
||||
async with sem:
|
||||
return await task
|
||||
|
||||
@ -228,7 +229,7 @@ async def resolve_issues(
|
||||
logger.info('Finished.')
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Resolve multiple issues from Github or Gitlab.'
|
||||
)
|
||||
@ -349,7 +350,7 @@ def main():
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import subprocess
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from termcolor import colored
|
||||
|
||||
import openhands
|
||||
@ -18,6 +19,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig, SandboxConf
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
@ -48,7 +50,7 @@ AGENT_CLASS = 'CodeActAgent'
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
platform: Platform,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
@ -192,26 +194,28 @@ async def process_issue(
|
||||
# This code looks unnecessary because these are default values in the config class
|
||||
# they're set by default if nothing else overrides them
|
||||
# FIXME we should remove them here
|
||||
kwargs = {}
|
||||
sandbox_config = SandboxConfig(
|
||||
runtime_container_image=runtime_container_image,
|
||||
enable_auto_lint=False,
|
||||
use_host_network=False,
|
||||
# large enough timeout, since some testcases take very long to run
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
if os.getenv('GITLAB_CI') == 'True':
|
||||
kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost')
|
||||
sandbox_config.local_runtime_url = os.getenv(
|
||||
'LOCAL_RUNTIME_URL', 'http://localhost'
|
||||
)
|
||||
user_id = os.getuid() if hasattr(os, 'getuid') else 1000
|
||||
if user_id == 0:
|
||||
kwargs['user_id'] = get_unique_uid()
|
||||
sandbox_config.user_id = get_unique_uid()
|
||||
|
||||
config = AppConfig(
|
||||
default_agent='CodeActAgent',
|
||||
runtime='docker',
|
||||
max_budget_per_task=4,
|
||||
max_iterations=max_iterations,
|
||||
sandbox=SandboxConfig(
|
||||
runtime_container_image=runtime_container_image,
|
||||
enable_auto_lint=False,
|
||||
use_host_network=False,
|
||||
# large enough timeout, since some testcases take very long to run
|
||||
timeout=300,
|
||||
**kwargs,
|
||||
),
|
||||
sandbox=sandbox_config,
|
||||
# do not mount workspace
|
||||
workspace_base=workspace_base,
|
||||
workspace_mount_path=workspace_base,
|
||||
@ -222,7 +226,7 @@ async def process_issue(
|
||||
runtime = create_runtime(config)
|
||||
await runtime.connect()
|
||||
|
||||
def on_event(evt):
|
||||
def on_event(evt: Event) -> None:
|
||||
logger.info(evt)
|
||||
|
||||
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||
@ -524,10 +528,10 @@ async def resolve_issue(
|
||||
logger.info('Finished.')
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
def int_or_none(value):
|
||||
def int_or_none(value: str) -> int | None:
|
||||
if value.lower() == 'none':
|
||||
return None
|
||||
else:
|
||||
@ -654,7 +658,7 @@ def main():
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import shutil
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -543,7 +544,7 @@ def process_all_successful_issues(
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Send a pull request to Github or Gitlab.'
|
||||
)
|
||||
@ -641,7 +642,7 @@ def main():
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
|
||||
@ -107,7 +107,7 @@ def codeact_user_response(
|
||||
return msg
|
||||
|
||||
|
||||
def cleanup():
|
||||
def cleanup() -> None:
|
||||
print('Cleaning up child processes...')
|
||||
for process in mp.active_children():
|
||||
print(f'Terminating child process: {process.name}')
|
||||
@ -115,7 +115,9 @@ def cleanup():
|
||||
process.join()
|
||||
|
||||
|
||||
def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
|
||||
def prepare_dataset(
|
||||
dataset: pd.DataFrame, output_file: str, eval_n_limit: int
|
||||
) -> pd.DataFrame:
|
||||
assert 'instance_id' in dataset.columns, (
|
||||
"Expected 'instance_id' column in the dataset. You should define your own "
|
||||
"unique identifier for each instance and use it as the 'instance_id' column."
|
||||
@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
|
||||
|
||||
def reset_logger_for_multiprocessing(
|
||||
logger: logging.Logger, instance_id: str, log_dir: str
|
||||
):
|
||||
) -> None:
|
||||
"""Reset the logger for multiprocessing.
|
||||
|
||||
Save logs to a separate file for each process, instead of trying to write to the
|
||||
@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]:
|
||||
return [int(match) for match in re.findall(pattern, body)]
|
||||
|
||||
|
||||
def get_unique_uid(start_uid=1000):
|
||||
def get_unique_uid(start_uid: int = 1000) -> int:
|
||||
existing_uids = set()
|
||||
with open('/etc/passwd', 'r') as passwd_file:
|
||||
for line in passwd_file:
|
||||
|
||||
@ -4,7 +4,9 @@ import os
|
||||
from openhands.resolver.io_utils import load_single_resolver_output
|
||||
|
||||
|
||||
def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str):
|
||||
def visualize_resolver_output(
|
||||
issue_number: int, output_dir: str, vis_method: str
|
||||
) -> None:
|
||||
output_jsonl = os.path.join(output_dir, 'output.jsonl')
|
||||
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
|
||||
if vis_method == 'json':
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user