Some checks failed
Remove old artifacts / remove-old-artifacts (push) Has been cancelled
108 lines
4.2 KiB
Python
108 lines
4.2 KiB
Python
from typing import Literal
|
|
from camel.toolkits import GithubToolkit as BaseGithubToolkit
|
|
from camel.toolkits.function_tool import FunctionTool
|
|
from app.component.environment import env
|
|
from app.service.task import Agents
|
|
from app.utils.listen.toolkit_listen import listen_toolkit
|
|
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
|
|
|
|
|
|
class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
    # Task-scoped GitHub toolkit.
    #
    # Each public GitHub operation of ``BaseGithubToolkit`` is overridden
    # for one reason only: to attach the ``listen_toolkit`` decorator. The
    # override bodies forward their arguments, unchanged and positionally,
    # to the base-class implementation and return its result as-is.
    #
    # ``listen_toolkit`` is given, in order:
    #   1. the base-class method being wrapped,
    #   2. a lambda that formats the call arguments into a message,
    #   3. (optionally) a lambda that formats the return value into a message.
    # How those messages are consumed is defined by ``listen_toolkit`` itself
    # and is not visible here.
    #
    # NOTE: documentation is written as ``#`` comments rather than docstrings
    # on purpose, so ``__doc__`` of the class and of the decorated methods is
    # exactly what it was before.

    # Agent this toolkit is attributed to.
    agent_name: str = Agents.developer_agent

    def __init__(
        self,
        api_task_id: str,
        access_token: str | None = None,
        timeout: float | None = None,
    ) -> None:
        # Parameters:
        #   api_task_id:  stored on the instance as ``self.api_task_id``.
        #   access_token: forwarded to the base toolkit; ``None`` leaves the
        #                 choice of token to the base class.
        #   timeout:      forwarded to the base toolkit unchanged.
        super().__init__(access_token, timeout)
        self.api_task_id = api_task_id

    # Create a pull request that writes ``new_content`` to ``file_path``.
    # The argument message includes the full ``new_content`` text.
    @listen_toolkit(
        BaseGithubToolkit.create_pull_request,
        lambda _, repo_name, file_path, new_content, pr_title, body, branch_name: (
            f"Create PR in {repo_name} for {file_path} with title '{pr_title}', branch '{branch_name}', content '{new_content}'"
        ),
    )
    def create_pull_request(
        self,
        repo_name: str,
        file_path: str,
        new_content: str,
        pr_title: str,
        body: str,
        branch_name: str,
    ) -> str:
        return super().create_pull_request(repo_name, file_path, new_content, pr_title, body, branch_name)

    # List issues of ``repo_name`` filtered by ``state`` (default "all").
    @listen_toolkit(
        BaseGithubToolkit.get_issue_list,
        lambda _, repo_name, state="all": f"Get issue list from {repo_name} with state '{state}'",
        lambda issues: f"Retrieved {len(issues)} issues",
    )
    def get_issue_list(
        self,
        repo_name: str,
        state: Literal["open", "closed", "all"] = "all",
    ) -> list[dict[str, object]]:
        return super().get_issue_list(repo_name, state)

    # Fetch the content of a single issue, identified by its number.
    @listen_toolkit(
        BaseGithubToolkit.get_issue_content,
        lambda _, repo_name, issue_number: f"Get content of issue {issue_number} from {repo_name}",
    )
    def get_issue_content(self, repo_name: str, issue_number: int) -> str:
        return super().get_issue_content(repo_name, issue_number)

    # List pull requests of ``repo_name`` filtered by ``state`` (default "all").
    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_list,
        lambda _, repo_name, state="all": f"Get pull request list from {repo_name} with state '{state}'",
        lambda prs: f"Retrieved {len(prs)} pull requests",
    )
    def get_pull_request_list(
        self,
        repo_name: str,
        state: Literal["open", "closed", "all"] = "all",
    ) -> list[dict[str, object]]:
        return super().get_pull_request_list(repo_name, state)

    # Fetch the code entries associated with pull request ``pr_number``.
    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_code,
        lambda _, repo_name, pr_number: f"Get code for pull request {pr_number} in {repo_name}",
        lambda code: f"Retrieved {len(code)} code files",
    )
    def get_pull_request_code(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        return super().get_pull_request_code(repo_name, pr_number)

    # Fetch the comments attached to pull request ``pr_number``.
    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_comments,
        lambda _, repo_name, pr_number: f"Get comments for pull request {pr_number} in {repo_name}",
        lambda comments: f"Retrieved {len(comments)} comments",
    )
    def get_pull_request_comments(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        return super().get_pull_request_comments(repo_name, pr_number)

    # List file paths in ``repo_name`` under ``path`` (default "").
    @listen_toolkit(
        BaseGithubToolkit.get_all_file_paths,
        lambda _, repo_name, path="": f"Get all file paths from {repo_name}, path '{path}'",
        lambda paths: f"Retrieved {len(paths)} file paths",
    )
    def get_all_file_paths(self, repo_name: str, path: str = "") -> list[str]:
        return super().get_all_file_paths(repo_name, path)

    # Retrieve the text content of ``file_path`` from ``repo_name``.
    @listen_toolkit(
        BaseGithubToolkit.retrieve_file_content,
        lambda _, repo_name, file_path: f"Retrieve content of file {file_path} from {repo_name}",
        lambda content: f"Retrieved content of length {len(content)}",
    )
    def retrieve_file_content(self, repo_name: str, file_path: str) -> str:
        return super().retrieve_file_content(repo_name, file_path)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        # Return the toolkit's tools for ``api_task_id``, or an empty list
        # when ``env("GITHUB_ACCESS_TOKEN")`` is falsy (no token configured).
        if not env("GITHUB_ACCESS_TOKEN"):
            return []
        # Instantiate ``cls`` rather than the hard-coded ``GithubToolkit`` so
        # a subclass calling this inherited classmethod gets tools bound to
        # its own type (and its own ``agent_name``), not to the base class.
        return cls(api_task_id).get_tools()
|