eigent/backend/app/utils/toolkit/github_toolkit.py
wol 03f2c49b7e
Some checks failed
Remove old artifacts / remove-old-artifacts (push) Has been cancelled
init
2025-08-04 00:20:29 +08:00

108 lines
4.2 KiB
Python

from typing import Literal
from camel.toolkits import GithubToolkit as BaseGithubToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
    """camel ``GithubToolkit`` wrapper bound to a single API task.

    Each public tool method delegates straight to the camel base
    implementation. It is decorated with ``listen_toolkit``, which receives
    the base method plus formatter lambdas. Those lambdas build
    human-readable descriptions of the call's inputs and, where given, of
    its result (presumably for task activity reporting; confirm against
    ``listen_toolkit``).
    """

    # Agent this toolkit is attributed to.
    agent_name: str = Agents.developer_agent

    def __init__(
        self,
        api_task_id: str,
        access_token: str | None = None,
        timeout: float | None = None,
    ) -> None:
        """Initialise the underlying camel toolkit and remember the task id.

        Args:
            api_task_id: Identifier of the API task this toolkit serves.
                Stored on the instance; presumably read by ``listen_toolkit``
                to associate events with the task (confirm).
            access_token: GitHub token forwarded to the base toolkit. When
                ``None``, the base class resolves the token on its own.
            timeout: Optional timeout forwarded to the base toolkit.
        """
        super().__init__(access_token, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseGithubToolkit.create_pull_request,
        # NOTE(review): this description embeds the full ``new_content``,
        # which may be an entire file; consider truncating if it is surfaced
        # to users or logs.
        lambda _, repo_name, file_path, new_content, pr_title, body, branch_name: (
            f"Create PR in {repo_name} for {file_path} with title '{pr_title}', "
            f"branch '{branch_name}', content '{new_content}'"
        ),
    )
    def create_pull_request(
        self,
        repo_name: str,
        file_path: str,
        new_content: str,
        pr_title: str,
        body: str,
        branch_name: str,
    ) -> str:
        """Create a pull request via the base toolkit.

        Args:
            repo_name: Target repository name.
            file_path: Path of the file the PR changes.
            new_content: New content for ``file_path``.
            pr_title: Pull request title.
            body: Pull request description.
            branch_name: Branch to use for the change.

        Returns:
            The string returned by the base implementation.
        """
        return super().create_pull_request(
            repo_name, file_path, new_content, pr_title, body, branch_name
        )

    @listen_toolkit(
        BaseGithubToolkit.get_issue_list,
        lambda _, repo_name, state="all": f"Get issue list from {repo_name} with state '{state}'",
        lambda issues: f"Retrieved {len(issues)} issues",
    )
    def get_issue_list(
        self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
    ) -> list[dict[str, object]]:
        """List issues of ``repo_name`` filtered by ``state`` (default ``"all"``)."""
        return super().get_issue_list(repo_name, state)

    @listen_toolkit(
        BaseGithubToolkit.get_issue_content,
        lambda _, repo_name, issue_number: f"Get content of issue {issue_number} from {repo_name}",
    )
    def get_issue_content(self, repo_name: str, issue_number: int) -> str:
        """Return the content of issue ``issue_number`` in ``repo_name``."""
        return super().get_issue_content(repo_name, issue_number)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_list,
        lambda _, repo_name, state="all": f"Get pull request list from {repo_name} with state '{state}'",
        lambda prs: f"Retrieved {len(prs)} pull requests",
    )
    def get_pull_request_list(
        self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
    ) -> list[dict[str, object]]:
        """List pull requests of ``repo_name`` filtered by ``state`` (default ``"all"``)."""
        return super().get_pull_request_list(repo_name, state)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_code,
        lambda _, repo_name, pr_number: f"Get code for pull request {pr_number} in {repo_name}",
        lambda code: f"Retrieved {len(code)} code files",
    )
    def get_pull_request_code(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        """Return the per-file code entries of pull request ``pr_number``."""
        return super().get_pull_request_code(repo_name, pr_number)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_comments,
        lambda _, repo_name, pr_number: f"Get comments for pull request {pr_number} in {repo_name}",
        lambda comments: f"Retrieved {len(comments)} comments",
    )
    def get_pull_request_comments(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        """Return the comments of pull request ``pr_number`` in ``repo_name``."""
        return super().get_pull_request_comments(repo_name, pr_number)

    @listen_toolkit(
        BaseGithubToolkit.get_all_file_paths,
        lambda _, repo_name, path="": f"Get all file paths from {repo_name}, path '{path}'",
        lambda paths: f"Retrieved {len(paths)} file paths",
    )
    def get_all_file_paths(self, repo_name: str, path: str = "") -> list[str]:
        """Return file paths in ``repo_name`` under ``path`` (default: root)."""
        return super().get_all_file_paths(repo_name, path)

    @listen_toolkit(
        BaseGithubToolkit.retrieve_file_content,
        lambda _, repo_name, file_path: f"Retrieve content of file {file_path} from {repo_name}",
        lambda content: f"Retrieved content of length {len(content)}",
    )
    def retrieve_file_content(self, repo_name: str, file_path: str) -> str:
        """Return the content of ``file_path`` in ``repo_name``."""
        return super().retrieve_file_content(repo_name, file_path)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        """Return this toolkit's tools if a GitHub token is configured.

        Args:
            api_task_id: Task id to bind the toolkit instance to.

        Returns:
            The toolkit's ``FunctionTool`` list when ``GITHUB_ACCESS_TOKEN``
            is set to a non-empty value, otherwise an empty list.
        """
        access_token = env("GITHUB_ACCESS_TOKEN")
        if not access_token:
            return []
        # Forward the exact token that gated availability, so the client
        # cannot end up using a token resolved from a different source.
        return cls(api_task_id, access_token=access_token).get_tools()