mirror of
https://github.com/camel-ai/owl.git
synced 2026-03-22 05:57:17 +08:00
refactor: Update with camel version 0.2.23
This commit is contained in:
29
.pre-commit-config.yaml
Normal file
29
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: 'v0.7.4'
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
|
||||||
|
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||||
|
- id: ruff-format
|
||||||
|
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
name: Check mypy
|
||||||
|
entry: mypy --namespace-packages -p owl
|
||||||
|
language: python
|
||||||
|
types: [python]
|
||||||
|
pass_filenames: false
|
||||||
|
require_serial: true
|
||||||
|
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: check-license
|
||||||
|
name: Check License
|
||||||
|
entry: python licenses/update_license.py . licenses/license_template.txt
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
exclude: ^docs/cookbooks/ # Ignore files under docs/cookbooks
|
||||||
43
README.md
43
README.md
@@ -102,36 +102,31 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
|
|||||||
|
|
||||||
# 🛠️ Installation
|
# 🛠️ Installation
|
||||||
|
|
||||||
## **Clone the Github repository**
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Clone github repo
|
||||||
git clone https://github.com/camel-ai/owl.git
|
git clone https://github.com/camel-ai/owl.git
|
||||||
|
|
||||||
|
# Change directory into project directory
|
||||||
cd owl
|
cd owl
|
||||||
```
|
|
||||||
|
|
||||||
## **Set up Environment**
|
# Install uv if you don't have it already
|
||||||
|
pip install uv
|
||||||
|
|
||||||
Using Conda (recommended):
|
# Create a virtual environment and install dependencies
|
||||||
```bash
|
# We support using Python 3.10, 3.11, 3.12
|
||||||
conda create -n owl python=3.11
|
uv venv .venv --python=3.10
|
||||||
conda activate owl
|
|
||||||
```
|
|
||||||
|
|
||||||
Using venv (alternative):
|
# Activate the virtual environment
|
||||||
```bash
|
# For macOS/Linux
|
||||||
python -m venv owl_env
|
source .venv/bin/activate
|
||||||
# On Windows
|
# For Windows
|
||||||
owl_env\Scripts\activate
|
.venv\Scripts\activate
|
||||||
# On Unix or MacOS
|
|
||||||
source owl_env/bin/activate
|
|
||||||
```
|
|
||||||
|
|
||||||
|
# Install CAMEL with all dependencies
|
||||||
|
uv pip install -e .
|
||||||
|
|
||||||
## **Install Dependencies**
|
# Exit the virtual environment when done
|
||||||
|
deactivate
|
||||||
```bash
|
|
||||||
python -m pip install -r requirements.txt
|
|
||||||
playwright install
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## **Setup Environment Variables**
|
## **Setup Environment Variables**
|
||||||
@@ -210,7 +205,7 @@ question = "Task description here."
|
|||||||
society = construct_society(question)
|
society = construct_society(question)
|
||||||
answer, chat_history, token_count = run_society(society)
|
answer, chat_history, token_count = run_society(society)
|
||||||
|
|
||||||
print(f"Answer: {answer}")
|
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||||
```
|
```
|
||||||
|
|
||||||
For uploading files, simply provide the file path along with your question:
|
For uploading files, simply provide the file path along with your question:
|
||||||
@@ -221,7 +216,7 @@ question = "What is in the given DOCX file? Here is the file path: tmp/example.d
|
|||||||
|
|
||||||
society = construct_society(question)
|
society = construct_society(question)
|
||||||
answer, chat_history, token_count = run_society(society)
|
answer, chat_history, token_count = run_society(society)
|
||||||
print(f"Answer: {answer}")
|
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||||
```
|
```
|
||||||
|
|
||||||
OWL will then automatically invoke document-related tools to process the file and extract the answer.
|
OWL will then automatically invoke document-related tools to process the file and extract the answer.
|
||||||
|
|||||||
37
README_zh.md
37
README_zh.md
@@ -104,31 +104,30 @@ https://private-user-images.githubusercontent.com/55657767/420212194-e813fc05-13
|
|||||||
## **克隆 Github 仓库**
|
## **克隆 Github 仓库**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# 克隆 GitHub 仓库
|
||||||
git clone https://github.com/camel-ai/owl.git
|
git clone https://github.com/camel-ai/owl.git
|
||||||
|
|
||||||
|
# 进入项目目录
|
||||||
cd owl
|
cd owl
|
||||||
```
|
|
||||||
|
|
||||||
## **设置环境**
|
# 如果你还没有安装 uv,请先安装
|
||||||
|
pip install uv
|
||||||
|
|
||||||
使用 Conda(推荐):
|
# 创建虚拟环境并安装依赖
|
||||||
```bash
|
# 我们支持使用 Python 3.10、3.11、3.12
|
||||||
conda create -n owl python=3.11
|
uv venv .venv --python=3.10
|
||||||
conda activate owl
|
|
||||||
```
|
|
||||||
|
|
||||||
使用 venv(备用):
|
# 激活虚拟环境
|
||||||
```bash
|
# 对于 macOS/Linux
|
||||||
python -m venv owl_env
|
source .venv/bin/activate
|
||||||
# Windows 系统
|
# 对于 Windows
|
||||||
owl_env\Scripts\activate
|
.venv\Scripts\activate
|
||||||
# Unix 或 MacOS 系统
|
|
||||||
source owl_env/bin/activate
|
|
||||||
```
|
|
||||||
|
|
||||||
## **安装依赖**
|
# 安装 CAMEL 及其所有依赖
|
||||||
|
uv pip install -e .
|
||||||
|
|
||||||
```bash
|
# 完成后退出虚拟环境
|
||||||
python -m pip install -r requirements.txt
|
deactivate
|
||||||
```
|
```
|
||||||
|
|
||||||
## **设置环境变量**
|
## **设置环境变量**
|
||||||
@@ -201,7 +200,7 @@ question = "Task description here."
|
|||||||
society = construct_society(question)
|
society = construct_society(question)
|
||||||
answer, chat_history, token_count = run_society(society)
|
answer, chat_history, token_count = run_society(society)
|
||||||
|
|
||||||
print(f"Answer: {answer}")
|
print(f"\033[94mAnswer: {answer}\033[0m")
|
||||||
```
|
```
|
||||||
|
|
||||||
上传文件时,只需提供文件路径和问题:
|
上传文件时,只需提供文件路径和问题:
|
||||||
|
|||||||
@@ -39,43 +39,37 @@ def update_license_in_file(
|
|||||||
start_line_start_with: str,
|
start_line_start_with: str,
|
||||||
end_line_start_with: str,
|
end_line_start_with: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
with open(
|
with open(file_path, "r", encoding="utf-8") as f: # for windows compatibility
|
||||||
file_path, 'r', encoding='utf-8'
|
|
||||||
) as f: # for windows compatibility
|
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
with open(license_template_path, 'r', encoding='utf-8') as f:
|
with open(license_template_path, "r", encoding="utf-8") as f:
|
||||||
new_license = f.read().strip()
|
new_license = f.read().strip()
|
||||||
|
|
||||||
maybe_existing_licenses = re.findall(
|
maybe_existing_licenses = re.findall(
|
||||||
r'^#.*?(?=\n)', content, re.MULTILINE | re.DOTALL
|
r"^#.*?(?=\n)", content, re.MULTILINE | re.DOTALL
|
||||||
)
|
)
|
||||||
start_index = fine_license_start_line(
|
start_index = fine_license_start_line(
|
||||||
maybe_existing_licenses, start_line_start_with
|
maybe_existing_licenses, start_line_start_with
|
||||||
)
|
)
|
||||||
end_index = find_license_end_line(
|
end_index = find_license_end_line(maybe_existing_licenses, end_line_start_with)
|
||||||
maybe_existing_licenses, end_line_start_with
|
|
||||||
)
|
|
||||||
if start_index is not None and end_index is not None:
|
if start_index is not None and end_index is not None:
|
||||||
maybe_existing_licenses = maybe_existing_licenses[
|
maybe_existing_licenses = maybe_existing_licenses[start_index : end_index + 1]
|
||||||
start_index : end_index + 1
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
maybe_existing_licenses = None
|
maybe_existing_licenses = None
|
||||||
if maybe_existing_licenses:
|
if maybe_existing_licenses:
|
||||||
maybe_old_licenses = '\n'.join(maybe_existing_licenses)
|
maybe_old_licenses = "\n".join(maybe_existing_licenses)
|
||||||
if maybe_old_licenses.strip() != new_license.strip():
|
if maybe_old_licenses.strip() != new_license.strip():
|
||||||
replaced_content = content.replace(maybe_old_licenses, new_license)
|
replaced_content = content.replace(maybe_old_licenses, new_license)
|
||||||
with open(file_path, 'w') as f:
|
with open(file_path, "w") as f:
|
||||||
f.write(replaced_content)
|
f.write(replaced_content)
|
||||||
print(f'Replaced license in {file_path}')
|
print(f"Replaced license in {file_path}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
with open(file_path, 'w') as f:
|
with open(file_path, "w") as f:
|
||||||
f.write(new_license + '\n' + content)
|
f.write(new_license + "\n" + content)
|
||||||
print(f'Added license to {file_path}')
|
print(f"Added license to {file_path}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -87,16 +81,16 @@ def update_license_in_directory(
|
|||||||
) -> None:
|
) -> None:
|
||||||
# Check if directory exists
|
# Check if directory exists
|
||||||
if not os.path.isdir(directory_path):
|
if not os.path.isdir(directory_path):
|
||||||
raise NotADirectoryError(f'{directory_path} is not a directory')
|
raise NotADirectoryError(f"{directory_path} is not a directory")
|
||||||
# Check if license template exists
|
# Check if license template exists
|
||||||
if not os.path.isfile(license_template_path):
|
if not os.path.isfile(license_template_path):
|
||||||
raise FileNotFoundError(f'{license_template_path} not found')
|
raise FileNotFoundError(f"{license_template_path} not found")
|
||||||
|
|
||||||
file_count = 0
|
file_count = 0
|
||||||
for py_files in Path(directory_path).rglob("*.py"):
|
for py_files in Path(directory_path).rglob("*.py"):
|
||||||
if py_files.name.startswith('.'):
|
if py_files.name.startswith("."):
|
||||||
continue
|
continue
|
||||||
if any(part.startswith('.') for part in py_files.parts):
|
if any(part.startswith(".") for part in py_files.parts):
|
||||||
continue
|
continue
|
||||||
if update_license_in_file(
|
if update_license_in_file(
|
||||||
py_files,
|
py_files,
|
||||||
@@ -106,10 +100,10 @@ def update_license_in_directory(
|
|||||||
):
|
):
|
||||||
file_count += 1
|
file_count += 1
|
||||||
|
|
||||||
print(f'License updated in {file_count} files')
|
print(f"License updated in {file_count} files")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) < 3:
|
if len(sys.argv) < 3:
|
||||||
print(
|
print(
|
||||||
"Usage from command line: "
|
"Usage from command line: "
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from camel.logger import disable_logging, enable_logging, set_log_level
|
|
||||||
|
|
||||||
__version__ = '0.2.11'
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'__version__',
|
|
||||||
'camel',
|
|
||||||
'disable_logging',
|
|
||||||
'enable_logging',
|
|
||||||
'set_log_level',
|
|
||||||
]
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .base import BaseAgent
|
|
||||||
from .chat_agent import ChatAgent
|
|
||||||
from .critic_agent import CriticAgent
|
|
||||||
from .embodied_agent import EmbodiedAgent
|
|
||||||
from .knowledge_graph_agent import KnowledgeGraphAgent
|
|
||||||
from .role_assignment_agent import RoleAssignmentAgent
|
|
||||||
from .search_agent import SearchAgent
|
|
||||||
from .task_agent import (
|
|
||||||
TaskCreationAgent,
|
|
||||||
TaskPlannerAgent,
|
|
||||||
TaskPrioritizationAgent,
|
|
||||||
TaskSpecifyAgent,
|
|
||||||
)
|
|
||||||
from .tool_agents.base import BaseToolAgent
|
|
||||||
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseAgent',
|
|
||||||
'ChatAgent',
|
|
||||||
'TaskSpecifyAgent',
|
|
||||||
'TaskPlannerAgent',
|
|
||||||
'TaskCreationAgent',
|
|
||||||
'TaskPrioritizationAgent',
|
|
||||||
'CriticAgent',
|
|
||||||
'BaseToolAgent',
|
|
||||||
'HuggingFaceToolAgent',
|
|
||||||
'EmbodiedAgent',
|
|
||||||
'RoleAssignmentAgent',
|
|
||||||
'SearchAgent',
|
|
||||||
'KnowledgeGraphAgent',
|
|
||||||
]
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
r"""An abstract base class for all CAMEL agents."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset(self, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
r"""Resets the agent to its initial state."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def step(self, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
r"""Performs a single step of the agent."""
|
|
||||||
pass
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,202 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import random
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, Optional, Sequence
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.memories import AgentMemory
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.responses import ChatAgentResponse
|
|
||||||
from camel.utils import get_first_int, print_text_animated
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="CriticAgent")
|
|
||||||
class CriticAgent(ChatAgent):
|
|
||||||
r"""A class for the critic agent that assists in selecting an option.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
system_message (BaseMessage): The system message for the critic
|
|
||||||
agent.
|
|
||||||
model (BaseModelBackend, optional): The model backend to use for
|
|
||||||
generating responses. (default: :obj:`OpenAIModel` with
|
|
||||||
`GPT_4O_MINI`)
|
|
||||||
message_window_size (int, optional): The maximum number of previous
|
|
||||||
messages to include in the context window. If `None`, no windowing
|
|
||||||
is performed. (default: :obj:`6`)
|
|
||||||
retry_attempts (int, optional): The number of retry attempts if the
|
|
||||||
critic fails to return a valid option. (default: :obj:`2`)
|
|
||||||
verbose (bool, optional): Whether to print the critic's messages.
|
|
||||||
logger_color (Any): The color of the menu options displayed to the
|
|
||||||
user. (default: :obj:`Fore.MAGENTA`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
system_message: BaseMessage,
|
|
||||||
model: Optional[BaseModelBackend] = None,
|
|
||||||
memory: Optional[AgentMemory] = None,
|
|
||||||
message_window_size: int = 6,
|
|
||||||
retry_attempts: int = 2,
|
|
||||||
verbose: bool = False,
|
|
||||||
logger_color: Any = Fore.MAGENTA,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
system_message,
|
|
||||||
model=model,
|
|
||||||
memory=memory,
|
|
||||||
message_window_size=message_window_size,
|
|
||||||
)
|
|
||||||
self.options_dict: Dict[str, str] = dict()
|
|
||||||
self.retry_attempts = retry_attempts
|
|
||||||
self.verbose = verbose
|
|
||||||
self.logger_color = logger_color
|
|
||||||
|
|
||||||
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
|
|
||||||
r"""Flattens the options to the critic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A string containing the flattened options to the critic.
|
|
||||||
"""
|
|
||||||
options = [message.content for message in messages]
|
|
||||||
flatten_options = (
|
|
||||||
f"> Proposals from "
|
|
||||||
f"{messages[0].role_name} ({messages[0].role_type}). "
|
|
||||||
"Please choose an option:\n"
|
|
||||||
)
|
|
||||||
for index, option in enumerate(options):
|
|
||||||
flatten_options += f"Option {index + 1}:\n{option}\n\n"
|
|
||||||
self.options_dict[str(index + 1)] = option
|
|
||||||
format = (
|
|
||||||
f"Please first enter your choice ([1-{len(self.options_dict)}]) "
|
|
||||||
"and then your explanation and comparison: "
|
|
||||||
)
|
|
||||||
return flatten_options + format
|
|
||||||
|
|
||||||
def get_option(self, input_message: BaseMessage) -> str:
|
|
||||||
r"""Gets the option selected by the critic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_message (BaseMessage): A `BaseMessage` object representing
|
|
||||||
the input message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The option selected by the critic.
|
|
||||||
"""
|
|
||||||
# TODO: Add support for editing options by the critic.
|
|
||||||
msg_content = input_message.content
|
|
||||||
i = 0
|
|
||||||
while i < self.retry_attempts:
|
|
||||||
critic_response = self.step(input_message)
|
|
||||||
|
|
||||||
if critic_response.msgs is None or len(critic_response.msgs) == 0:
|
|
||||||
raise RuntimeError("Got None critic messages.")
|
|
||||||
if critic_response.terminated:
|
|
||||||
raise RuntimeError("Critic step failed.")
|
|
||||||
|
|
||||||
critic_msg = critic_response.msg
|
|
||||||
if self.verbose:
|
|
||||||
print_text_animated(
|
|
||||||
self.logger_color + "\n> Critic response: "
|
|
||||||
f"\x1b[3m{critic_msg.content}\x1b[0m\n"
|
|
||||||
)
|
|
||||||
choice = self.parse_critic(critic_msg)
|
|
||||||
|
|
||||||
if choice in self.options_dict:
|
|
||||||
return self.options_dict[choice]
|
|
||||||
else:
|
|
||||||
input_message = BaseMessage(
|
|
||||||
role_name=input_message.role_name,
|
|
||||||
role_type=input_message.role_type,
|
|
||||||
meta_dict=input_message.meta_dict,
|
|
||||||
content="> Invalid choice. Please choose again.\n"
|
|
||||||
+ msg_content,
|
|
||||||
)
|
|
||||||
i += 1
|
|
||||||
warnings.warn(
|
|
||||||
"Critic failed to get a valid option. "
|
|
||||||
f"After {self.retry_attempts} attempts. "
|
|
||||||
"Returning a random option."
|
|
||||||
)
|
|
||||||
return random.choice(list(self.options_dict.values()))
|
|
||||||
|
|
||||||
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
|
|
||||||
r"""Parses the critic's message and extracts the choice.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
critic_msg (BaseMessage): A `BaseMessage` object representing the
|
|
||||||
critic's response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[str]: The critic's choice as a string, or None if the
|
|
||||||
message could not be parsed.
|
|
||||||
"""
|
|
||||||
choice = str(get_first_int(critic_msg.content))
|
|
||||||
return choice
|
|
||||||
|
|
||||||
def reduce_step(
|
|
||||||
self,
|
|
||||||
input_messages: Sequence[BaseMessage],
|
|
||||||
) -> ChatAgentResponse:
|
|
||||||
r"""Performs one step of the conversation by flattening options to the
|
|
||||||
critic, getting the option, and parsing the choice.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_messages (Sequence[BaseMessage]): A list of BaseMessage
|
|
||||||
objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatAgentResponse: A `ChatAgentResponse` object includes the
|
|
||||||
critic's choice.
|
|
||||||
"""
|
|
||||||
meta_chat_message = BaseMessage(
|
|
||||||
role_name=input_messages[0].role_name,
|
|
||||||
role_type=input_messages[0].role_type,
|
|
||||||
meta_dict=input_messages[0].meta_dict,
|
|
||||||
content="",
|
|
||||||
)
|
|
||||||
|
|
||||||
flatten_options = self.flatten_options(input_messages)
|
|
||||||
if self.verbose:
|
|
||||||
print_text_animated(
|
|
||||||
self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
|
|
||||||
)
|
|
||||||
input_msg = meta_chat_message.create_new_instance(flatten_options)
|
|
||||||
|
|
||||||
option = self.get_option(input_msg)
|
|
||||||
output_msg = meta_chat_message.create_new_instance(option)
|
|
||||||
|
|
||||||
# TODO: The return `info` can be improved.
|
|
||||||
return ChatAgentResponse(
|
|
||||||
msgs=[output_msg],
|
|
||||||
terminated=False,
|
|
||||||
info={},
|
|
||||||
)
|
|
||||||
@@ -1,303 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import re
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.logger import get_logger
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.prompts import TextPrompt
|
|
||||||
from camel.types import RoleType
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="DeductiveReasonerAgent")
|
|
||||||
class DeductiveReasonerAgent(ChatAgent):
|
|
||||||
r"""An agent responsible for deductive reasoning. Model of deductive
|
|
||||||
reasoning:
|
|
||||||
- L: A ⊕ C -> q * B
|
|
||||||
- A represents the known starting state.
|
|
||||||
- B represents the known target state.
|
|
||||||
- C represents the conditions required to transition from A to B.
|
|
||||||
- Q represents the quality or effectiveness of the transition from
|
|
||||||
A to B.
|
|
||||||
- L represents the path or process from A to B.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (BaseModelBackend, optional): The model backend to use for
|
|
||||||
generating responses. (default: :obj:`OpenAIModel` with
|
|
||||||
`GPT_4O_MINI`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: Optional[BaseModelBackend] = None,
|
|
||||||
) -> None:
|
|
||||||
system_message = BaseMessage(
|
|
||||||
role_name="Insight Agent",
|
|
||||||
role_type=RoleType.ASSISTANT,
|
|
||||||
meta_dict=None,
|
|
||||||
content="You assign roles based on tasks.",
|
|
||||||
)
|
|
||||||
super().__init__(system_message, model=model)
|
|
||||||
|
|
||||||
def deduce_conditions_and_quality(
|
|
||||||
self,
|
|
||||||
starting_state: str,
|
|
||||||
target_state: str,
|
|
||||||
role_descriptions_dict: Optional[Dict[str, str]] = None,
|
|
||||||
) -> Dict[str, Union[List[str], Dict[str, str]]]:
|
|
||||||
r"""Derives the conditions and quality from the starting state and the
|
|
||||||
target state based on the model of the deductive reasoning and the
|
|
||||||
knowledge base. It can optionally consider the roles involved in the
|
|
||||||
scenario, which allows tailoring the output more closely to the AI
|
|
||||||
agent's environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
starting_state (str): The initial or starting state from which
|
|
||||||
conditions are deduced.
|
|
||||||
target_state (str): The target state of the task.
|
|
||||||
role_descriptions_dict (Optional[Dict[str, str]], optional): The
|
|
||||||
descriptions of the roles. (default: :obj:`None`)
|
|
||||||
role_descriptions_dict (Optional[Dict[str, str]], optional): A
|
|
||||||
dictionary describing the roles involved in the scenario. This
|
|
||||||
is optional and can be used to provide a context for the
|
|
||||||
CAMEL's role-playing, enabling the generation of more relevant
|
|
||||||
and tailored conditions and quality assessments. This could be
|
|
||||||
generated using a `RoleAssignmentAgent()` or defined manually
|
|
||||||
by the user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
|
|
||||||
extracted data from the message. The dictionary contains three
|
|
||||||
keys:
|
|
||||||
- 'conditions': A list where each key is a condition ID and
|
|
||||||
each value is the corresponding condition text.
|
|
||||||
- 'labels': A list of label strings extracted from the message.
|
|
||||||
- 'quality': A string of quality assessment strings extracted
|
|
||||||
from the message.
|
|
||||||
"""
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
deduce_prompt = """You are a deductive reasoner. You are tasked to
|
|
||||||
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
|
|
||||||
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
|
|
||||||
CONTENT to help you complete the TASK.
|
|
||||||
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
|
|
||||||
fill in the BLANKs, and DO NOT alter or modify any other part of the template
|
|
||||||
|
|
||||||
===== MODELING OF DEDUCTIVE REASONING =====
|
|
||||||
You are tasked with understanding a mathematical model based on the components
|
|
||||||
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
|
|
||||||
- $A$ represents the known starting state.
|
|
||||||
- $B$ represents the known target state.
|
|
||||||
- $C$ represents the conditions required to transition from $A$ to $B$.
|
|
||||||
- $Q$ represents the quality or effectiveness of the transition from $A$ to
|
|
||||||
$B$.
|
|
||||||
- $L$ represents the path or process from $A$ to $B$.
|
|
||||||
|
|
||||||
===== THOUGHT OF DEDUCTIVE REASONING =====
|
|
||||||
1. Define the Parameters of A and B:
|
|
||||||
- Characterization: Before delving into transitions, thoroughly understand
|
|
||||||
the nature and boundaries of both $A$ and $B$. This includes the type,
|
|
||||||
properties, constraints, and possible interactions between the two.
|
|
||||||
- Contrast and Compare: Highlight the similarities and differences between
|
|
||||||
$A$ and $B$. This comparative analysis will give an insight into what
|
|
||||||
needs changing and what remains constant.
|
|
||||||
2. Historical & Empirical Analysis:
|
|
||||||
- Previous Transitions according to the Knowledge Base of GPT: (if
|
|
||||||
applicable) Extract conditions and patterns from the historical instances
|
|
||||||
where a similar transition from a state comparable to $A$ moved towards
|
|
||||||
$B$.
|
|
||||||
- Scientific Principles: (if applicable) Consider the underlying
|
|
||||||
scientific principles governing or related to the states and their
|
|
||||||
transition. For example, if $A$ and $B$ are physical states, laws of
|
|
||||||
physics might apply.
|
|
||||||
3. Logical Deduction of Conditions ($C$):
|
|
||||||
- Direct Path Analysis: What are the immediate and direct conditions
|
|
||||||
required to move from $A$ to $B$?
|
|
||||||
- Intermediate States: Are there states between $A$ and $B$ that must be
|
|
||||||
traversed or can be used to make the transition smoother or more
|
|
||||||
efficient? If yes, what is the content?
|
|
||||||
- Constraints & Limitations: Identify potential barriers or restrictions
|
|
||||||
in moving from $A$ to $B$. These can be external (e.g., environmental
|
|
||||||
factors) or internal (properties of $A$ or $B$).
|
|
||||||
- Resource and Information Analysis: What resources and information are
|
|
||||||
required for the transition? This could be time, entity, factor, code
|
|
||||||
language, software platform, unknowns, etc.
|
|
||||||
- External Influences: Consider socio-economic, political, or
|
|
||||||
environmental factors (if applicable) that could influence the transition
|
|
||||||
conditions.
|
|
||||||
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
|
|
||||||
no matter how unconventional they might seem. Utilize analogies,
|
|
||||||
metaphors, or brainstorming techniques to envision possible conditions or
|
|
||||||
paths from $A$ to $B$.
|
|
||||||
- The conditions $C$ should be multiple but in one sentence. And each
|
|
||||||
condition should be concerned with one aspect/entity.
|
|
||||||
4. Entity/Label Recognition of Conditions ($C$):
|
|
||||||
- Identify and categorize entities of Conditions ($C$) such as the names,
|
|
||||||
locations, dates, specific technical terms or contextual parameters that
|
|
||||||
might be associated with events, innovations post-2022.
|
|
||||||
- The output of the entities/labels will be used as tags or labels for
|
|
||||||
semantic similarity searches. The entities/labels may be the words, or
|
|
||||||
phrases, each of them should contain valuable, high information entropy
|
|
||||||
information, and should be independent.
|
|
||||||
- Ensure that the identified entities are formatted in a manner suitable
|
|
||||||
for database indexing and retrieval. Organize the entities into
|
|
||||||
categories, and combine the category with its instance into a continuous
|
|
||||||
phrase, without using colons or other separators.
|
|
||||||
- Format these entities for database indexing: output the category rather
|
|
||||||
than its instance/content into a continuous phrase. For example, instead
|
|
||||||
of "Jan. 02", identify it as "Event time".
|
|
||||||
5. Quality Assessment ($Q$):
|
|
||||||
- Efficiency: How efficient is the transition from $A$ to $B$, which
|
|
||||||
measures the resources used versus the desired outcome?
|
|
||||||
- Effectiveness: Did the transition achieve the desired outcome or was the
|
|
||||||
target state achieved as intended?
|
|
||||||
- Safety & Risks: Assess any risks associated with the transition and the
|
|
||||||
measures to mitigate them.
|
|
||||||
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
|
|
||||||
and adjust the quality of transition, making it more adaptive.
|
|
||||||
6. Iterative Evaluation:
|
|
||||||
- Test & Refine: Based on the initially deduced conditions and assessed
|
|
||||||
quality, iterate the process to refine and optimize the transition. This
|
|
||||||
might involve tweaking conditions, employing different paths, or changing
|
|
||||||
resources.
|
|
||||||
- Feedback Integration: Use feedback to make improvements and increase the
|
|
||||||
quality of the transition.
|
|
||||||
7. Real-world scenarios often present challenges that may not be captured by
|
|
||||||
models and frameworks. While using the model, maintain an adaptive mindset:
|
|
||||||
- Scenario Exploration: Continuously imagine various possible scenarios,
|
|
||||||
both positive and negative, to prepare for unexpected events.
|
|
||||||
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
|
|
||||||
process ($L$) if unforeseen challenges arise.
|
|
||||||
- Feedback Integration: Rapidly integrate feedback from actual
|
|
||||||
implementations to adjust the model's application, ensuring relevancy and
|
|
||||||
effectiveness.
|
|
||||||
|
|
||||||
===== TASK =====
|
|
||||||
Given the starting state $A$ and the target state $B$, assuming that a path
|
|
||||||
$L$ always exists between $A$ and $B$, how can one deduce or identify the
|
|
||||||
necessary conditions $C$ and the quality $Q$ of the transition?
|
|
||||||
|
|
||||||
===== STARTING STATE $A$ =====
|
|
||||||
{starting_state}
|
|
||||||
|
|
||||||
===== TARGET STATE $B$ =====
|
|
||||||
{target_state}
|
|
||||||
|
|
||||||
{role_with_description_prompt}
|
|
||||||
===== ANSWER TEMPLATE =====
|
|
||||||
- Characterization and comparison of $A$ and $B$:\n<BLANK>
|
|
||||||
- Historical & Empirical Analysis:\n<BLANK>/None
|
|
||||||
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
|
|
||||||
condition <NUM>:
|
|
||||||
<BLANK>.
|
|
||||||
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
|
|
||||||
square brackets)
|
|
||||||
- Quality Assessment ($Q$) (do not use symbols):
|
|
||||||
<BLANK>.
|
|
||||||
- Iterative Evaluation:\n<BLANK>/None"""
|
|
||||||
|
|
||||||
if role_descriptions_dict is not None:
|
|
||||||
role_names = role_descriptions_dict.keys()
|
|
||||||
role_with_description_prompt = (
|
|
||||||
"===== ROLES WITH DESCRIPTIONS =====\n"
|
|
||||||
+ "\n".join(
|
|
||||||
f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
|
|
||||||
for role_name in role_names
|
|
||||||
)
|
|
||||||
+ "\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
role_with_description_prompt = ""
|
|
||||||
deduce_prompt = TextPrompt(deduce_prompt)
|
|
||||||
|
|
||||||
deduce = deduce_prompt.format(
|
|
||||||
starting_state=starting_state,
|
|
||||||
target_state=target_state,
|
|
||||||
role_with_description_prompt=role_with_description_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
conditions_and_quality_generation_msg = BaseMessage.make_user_message(
|
|
||||||
role_name="Deductive Reasoner", content=deduce
|
|
||||||
)
|
|
||||||
|
|
||||||
response = self.step(
|
|
||||||
input_message=conditions_and_quality_generation_msg
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.terminated:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Deduction failed. Error:\n" + f"{response.info}"
|
|
||||||
)
|
|
||||||
msg: BaseMessage = response.msg
|
|
||||||
logger.info(f"Message content:\n{msg.content}")
|
|
||||||
|
|
||||||
# Extract the conditions from the message
|
|
||||||
conditions_dict = {
|
|
||||||
f"condition {i}": cdt.replace("<", "")
|
|
||||||
.replace(">", "")
|
|
||||||
.strip()
|
|
||||||
.strip('\n')
|
|
||||||
for i, cdt in re.findall(
|
|
||||||
r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
|
|
||||||
msg.content,
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Extract the labels from the message
|
|
||||||
labels = [
|
|
||||||
label.strip().strip('\n').strip("\"'")
|
|
||||||
for label in re.findall(
|
|
||||||
r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
|
|
||||||
msg.content,
|
|
||||||
re.DOTALL,
|
|
||||||
)[0].split(",")
|
|
||||||
]
|
|
||||||
|
|
||||||
# Extract the quality from the message
|
|
||||||
quality = next(
|
|
||||||
q.strip().strip('\n')
|
|
||||||
for q in re.findall(
|
|
||||||
r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
|
|
||||||
r"\n(.+?)- Iterative",
|
|
||||||
msg.content,
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert them into JSON format
|
|
||||||
conditions_and_quality_json: Dict[
|
|
||||||
str, Union[List[str], Dict[str, str]]
|
|
||||||
] = {}
|
|
||||||
conditions_and_quality_json["conditions"] = conditions_dict
|
|
||||||
conditions_and_quality_json["labels"] = labels
|
|
||||||
conditions_and_quality_json["evaluate_quality"] = quality
|
|
||||||
|
|
||||||
return conditions_and_quality_json
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.agents.tool_agents.base import BaseToolAgent
|
|
||||||
from camel.interpreters import (
|
|
||||||
BaseInterpreter,
|
|
||||||
InternalPythonInterpreter,
|
|
||||||
SubprocessInterpreter,
|
|
||||||
)
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.responses import ChatAgentResponse
|
|
||||||
from camel.utils import print_text_animated
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
    r"""Class for managing conversations of CAMEL Embodied Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        tool_agents (List[BaseToolAgent], optional): The tools agents to use in
            the embodied agent. (default: :obj:`None`)
        code_interpreter (BaseInterpreter, optional): The code interpreter to
            execute codes. If `code_interpreter` and `tool_agent` are both
            `None`, default to `SubProcessInterpreter`. If `code_interpreter`
            is `None` and `tool_agents` is not `None`, default to
            `InternalPythonInterpreter`. (default: :obj:`None`)
        verbose (bool, optional): Whether to print the critic's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the logger displayed to the user.
            (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        message_window_size: Optional[int] = None,
        tool_agents: Optional[List[BaseToolAgent]] = None,
        code_interpreter: Optional[BaseInterpreter] = None,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.tool_agents = tool_agents
        self.code_interpreter: BaseInterpreter
        # Interpreter selection: explicit choice wins; otherwise tool agents
        # require the in-process interpreter, else run code in a subprocess.
        if code_interpreter is not None:
            self.code_interpreter = code_interpreter
        elif self.tool_agents:
            self.code_interpreter = InternalPythonInterpreter()
        else:
            self.code_interpreter = SubprocessInterpreter()

        if self.tool_agents:
            # Inject the tool-agent action space into the system message.
            system_message = self._set_tool_agents(system_message)
        self.verbose = verbose
        self.logger_color = logger_color
        super().__init__(
            system_message=system_message,
            model=model,
            message_window_size=message_window_size,
        )

    def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
        r"""Formats the system message with the tool-agent action space and
        registers the tool agents with the code interpreter.

        Args:
            system_message (BaseMessage): The original system message; its
                content must contain an ``{action_space}`` placeholder.

        Returns:
            BaseMessage: A new message instance with the action space filled
                in.
        """
        action_space_prompt = self._get_tool_agents_prompt()
        result_message = system_message.create_new_instance(
            content=system_message.content.format(
                action_space=action_space_prompt
            )
        )
        if self.tool_agents is not None:
            self.code_interpreter.update_action_space(
                {tool.name: tool for tool in self.tool_agents}
            )
        return result_message

    def _get_tool_agents_prompt(self) -> str:
        r"""Returns the action space prompt.

        Returns:
            str: The action space prompt.
        """
        if self.tool_agents is not None:
            return "\n".join(
                [
                    f"*** {tool.name} ***:\n {tool.description}"
                    for tool in self.tool_agents
                ]
            )
        else:
            return ""

    def get_tool_agent_names(self) -> List[str]:
        r"""Returns the names of tool agents.

        Returns:
            List[str]: The names of tool agents.
        """
        if self.tool_agents is not None:
            return [tool.name for tool in self.tool_agents]
        else:
            return []

    # ruff: noqa: E501
    def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
        r"""Performs a step in the conversation: queries the model, executes
        any code blocks it produced, and appends the execution results to the
        input message.

        Args:
            input_message (BaseMessage): The input message.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.

        Raises:
            RuntimeError: If the model returned no messages or terminated.
        """
        response = super().step(input_message)

        if response.msgs is None or len(response.msgs) == 0:
            raise RuntimeError("Got None output messages.")
        if response.terminated:
            raise RuntimeError(f"{self.__class__.__name__} step failed.")

        # NOTE: Only single output messages are supported
        explanations, codes = response.msg.extract_text_and_code_prompts()

        if self.verbose:
            for explanation, code in zip(explanations, codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanation}"
                )
                print_text_animated(self.logger_color + f"> Code:\n{code}")

            # There may be one trailing explanation after the last code block.
            if len(explanations) > len(codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanations[-1]}"
                )

        content = response.msg.content

        # Fix: `codes` is a (possibly empty) list, so checking `is not None`
        # was always true and clobbered the message content even when the
        # response contained no code blocks. Only replace the content when
        # there is code to execute.
        if codes:
            try:
                content = "\n> Executed Results:\n"
                for block_idx, code in enumerate(codes):
                    executed_output = self.code_interpreter.run(
                        code, code.code_type
                    )
                    content += (
                        f"Executing code block {block_idx}: {{\n"
                        + executed_output
                        + "}\n"
                    )
            except InterruptedError as e:
                content = (
                    f"\n> Running code fail: {e}\n"
                    "Please regenerate the code."
                )

        # TODO: Handle errors
        content = input_message.content + f"\n> Embodied Actions:\n{content}"
        message = BaseMessage(
            input_message.role_name,
            input_message.role_type,
            input_message.meta_dict,
            content,
        )
        return ChatAgentResponse(
            msgs=[message],
            terminated=response.terminated,
            info=response.info,
        )
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from unstructured.documents.elements import Element
|
|
||||||
|
|
||||||
from camel.agents import ChatAgent
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.prompts import TextPrompt
|
|
||||||
from camel.storages.graph_storages.graph_element import (
|
|
||||||
GraphElement,
|
|
||||||
Node,
|
|
||||||
Relationship,
|
|
||||||
)
|
|
||||||
from camel.types import RoleType
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
text_prompt = """
|
|
||||||
You are tasked with extracting nodes and relationships from given content and
|
|
||||||
structures them into Node and Relationship objects. Here's the outline of what
|
|
||||||
you needs to do:
|
|
||||||
|
|
||||||
Content Extraction:
|
|
||||||
You should be able to process input content and identify entities mentioned
|
|
||||||
within it.
|
|
||||||
Entities can be any noun phrases or concepts that represent distinct entities
|
|
||||||
in the context of the given content.
|
|
||||||
|
|
||||||
Node Extraction:
|
|
||||||
For each identified entity, you should create a Node object.
|
|
||||||
Each Node object should have a unique identifier (id) and a type (type).
|
|
||||||
Additional properties associated with the node can also be extracted and
|
|
||||||
stored.
|
|
||||||
|
|
||||||
Relationship Extraction:
|
|
||||||
You should identify relationships between entities mentioned in the content.
|
|
||||||
For each relationship, create a Relationship object.
|
|
||||||
A Relationship object should have a subject (subj) and an object (obj) which
|
|
||||||
are Node objects representing the entities involved in the relationship.
|
|
||||||
Each relationship should also have a type (type), and additional properties if
|
|
||||||
applicable.
|
|
||||||
|
|
||||||
Output Formatting:
|
|
||||||
The extracted nodes and relationships should be formatted as instances of the
|
|
||||||
provided Node and Relationship classes.
|
|
||||||
Ensure that the extracted data adheres to the structure defined by the classes.
|
|
||||||
Output the structured data in a format that can be easily validated against
|
|
||||||
the provided code.
|
|
||||||
|
|
||||||
Instructions for you:
|
|
||||||
Read the provided content thoroughly.
|
|
||||||
Identify distinct entities mentioned in the content and categorize them as
|
|
||||||
nodes.
|
|
||||||
Determine relationships between these entities and represent them as directed
|
|
||||||
relationships.
|
|
||||||
Provide the extracted nodes and relationships in the specified format below.
|
|
||||||
Example for you:
|
|
||||||
|
|
||||||
Example Content:
|
|
||||||
"John works at XYZ Corporation. He is a software engineer. The company is
|
|
||||||
located in New York City."
|
|
||||||
|
|
||||||
Expected Output:
|
|
||||||
|
|
||||||
Nodes:
|
|
||||||
|
|
||||||
Node(id='John', type='Person')
|
|
||||||
Node(id='XYZ Corporation', type='Organization')
|
|
||||||
Node(id='New York City', type='Location')
|
|
||||||
|
|
||||||
Relationships:
|
|
||||||
|
|
||||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
|
|
||||||
Corporation', type='Organization'), type='WorksAt')
|
|
||||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
|
|
||||||
type='Location'), type='ResidesIn')
|
|
||||||
|
|
||||||
===== TASK =====
|
|
||||||
Please extracts nodes and relationships from given content and structures them
|
|
||||||
into Node and Relationship objects.
|
|
||||||
|
|
||||||
{task}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        # Fresh conversation per extraction; keep the element for provenance
        # when building the GraphElement below.
        self.reset()
        self.element = element

        knowledge_graph_prompt = TextPrompt(text_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        knowledge_graph_generation_msg = BaseMessage.make_user_message(
            role_name="Graphify", content=knowledge_graph_generation
        )

        response = self.step(input_message=knowledge_graph_generation_msg)

        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        The model is expected to emit ``Node(id='...', type='...')`` and
        ``Relationship(subj=Node(...), obj=Node(...), type='...')`` literals
        (see the module-level prompt); these are recovered with regexes.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions to extract nodes and relationships
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes (first occurrence of an id wins; duplicates skipped).
        # Locals renamed from `id`/`type` to avoid shadowing the builtins.
        for match in re.finditer(node_pattern, input_string):
            node_id, node_type = match.groups()
            properties = {'source': 'agent_created'}
            if node_id not in nodes:
                node = Node(id=node_id, type=node_type, properties=properties)
                if self._validate_node(node):
                    nodes[node_id] = node

        # Extract relationships; only keep those whose endpoints were
        # successfully parsed as nodes above.
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
            properties = {'source': 'agent_created'}
            if subj_id in nodes and obj_id in nodes:
                subj = nodes[subj_id]
                obj = nodes[obj_id]
                relationship = Relationship(
                    subj=subj, obj=obj, type=rel_type, properties=properties
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import re
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.prompts import TextPrompt
|
|
||||||
from camel.types import RoleType
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="RoleAssignmentAgent")
|
|
||||||
class RoleAssignmentAgent(ChatAgent):
|
|
||||||
r"""An agent that generates role names based on the task prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (BaseModelBackend, optional): The model backend to use for
|
|
||||||
generating responses. (default: :obj:`OpenAIModel` with
|
|
||||||
`GPT_4O_MINI`)
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
role_assignment_prompt (TextPrompt): A prompt for the agent to generate
|
|
||||||
role names.
|
|
||||||
"""
|
|
||||||
|
|
||||||
    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `RoleAssignmentAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        # Fixed system persona: the agent only recruits/describes roles.
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
task_prompt: Union[str, TextPrompt],
|
|
||||||
num_roles: int = 2,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
r"""Generate role names based on the input task prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_prompt (Union[str, TextPrompt]): The prompt
|
|
||||||
for the task based on which the roles are to be generated.
|
|
||||||
num_roles (int, optional): The number of roles to generate.
|
|
||||||
(default: :obj:`2`)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, str]: A dictionary mapping role names to their
|
|
||||||
descriptions.
|
|
||||||
"""
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
|
|
||||||
f"Domain expert {i + 1}: <BLANK>\n"
|
|
||||||
f"Associated competencies, characteristics, duties "
|
|
||||||
f"and workflows: <BLANK>. End."
|
|
||||||
for i in range(num_roles or 0)
|
|
||||||
)
|
|
||||||
role_assignment_generation_prompt = TextPrompt(
|
|
||||||
"You are a role assignment agent, and you're in charge of "
|
|
||||||
+ "recruiting {num_roles} experts for the following task."
|
|
||||||
+ "\n==== TASK =====\n {task}\n\n"
|
|
||||||
+ "Identify the domain experts you'd recruit and detail their "
|
|
||||||
+ "associated competencies, characteristics, duties and workflows "
|
|
||||||
+ "to complete the task.\n "
|
|
||||||
+ "Your answer MUST adhere to the format of ANSWER PROMPT, and "
|
|
||||||
+ "ONLY answer the BLANKs.\n"
|
|
||||||
+ expert_prompt
|
|
||||||
)
|
|
||||||
role_assignment_generation = role_assignment_generation_prompt.format(
|
|
||||||
num_roles=num_roles, task=task_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
role_assignment_generation_msg = BaseMessage.make_user_message(
|
|
||||||
role_name="Role Assigner", content=role_assignment_generation
|
|
||||||
)
|
|
||||||
|
|
||||||
response = self.step(input_message=role_assignment_generation_msg)
|
|
||||||
|
|
||||||
msg = response.msg # type: BaseMessage
|
|
||||||
terminated = response.terminated
|
|
||||||
|
|
||||||
# Distribute the output completions into role names and descriptions
|
|
||||||
role_names = [
|
|
||||||
desc.replace("<|", "").replace("|>", "")
|
|
||||||
for desc in re.findall(
|
|
||||||
r"Domain expert \d: (.+?)\nAssociated competencies,",
|
|
||||||
msg.content,
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
role_descriptions = [
|
|
||||||
desc.replace("<|", "").replace("|>", "")
|
|
||||||
for desc in re.findall(
|
|
||||||
r"Associated competencies, characteristics, "
|
|
||||||
r"duties and workflows: (.+?) End.",
|
|
||||||
msg.content,
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(role_names) != num_roles or len(role_descriptions) != num_roles:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Got None or insufficient information of roles."
|
|
||||||
)
|
|
||||||
if terminated:
|
|
||||||
raise RuntimeError("Role assignment failed.")
|
|
||||||
|
|
||||||
role_descriptions_dict = {
|
|
||||||
role_name: description
|
|
||||||
for role_name, description in zip(role_names, role_descriptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
return role_descriptions_dict
|
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.prompts import TextPrompt
|
|
||||||
from camel.types import RoleType
|
|
||||||
from camel.utils import create_chunks
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        assistant_sys_msg = BaseMessage(
            role_name="Assistant",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful assistant.",
        )
        super().__init__(assistant_sys_msg, model=model)

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, base on the query.

        Args:
            text (str): Text to summarize.
            query (str): What information you want.

        Returns:
            str: Strings with information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that relative to the
question, but do not directly answer the question.\nquestion:
{query}\ntext '''
        )
        summary_prompt = summary_prompt.format(query=query)
        # Cap each chunk sent to the model at this many characters.
        chunk_limit = 3000
        partial_summaries = []
        # Summarize every chunk independently, tagging each with its index.
        for chunk_id, chunk in enumerate(
            create_chunks(text, chunk_limit), start=1
        ):
            chunk_msg = BaseMessage.make_user_message(
                role_name="User",
                content=summary_prompt + str(chunk_id) + ": " + chunk,
            )
            partial_summaries.append(self.step(chunk_msg).msg.content)
        aggregated = "".join(piece + "\n" for piece in partial_summaries)

        # Merge the per-chunk summaries into a single answer.
        final_prompt = TextPrompt(
            '''Here are some summarized texts which split from one text. Using
the information to answer the question. If can't find the answer,
you must answer "I can not find the answer to the query" and
explain why.\n Query:\n{query}.\n\nText:\n'''
        )
        final_prompt = final_prompt.format(query=query)

        combined_msg = BaseMessage.make_user_message(
            role_name="User",
            content=final_prompt + aggregated,
        )
        return self.step(combined_msg).msg.content

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue search or not based on the provided answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the user want to continue search, `False`
            otherwise.
        """
        check_prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        )
        check_prompt = check_prompt.format(query=query, answer=answer)
        check_msg = BaseMessage.make_user_message(
            role_name="User",
            content=check_prompt,
        )
        verdict = self.step(check_msg).msg.content
        # A "yes" verdict means the answer already suffices, so stop searching.
        return "yes" not in str(verdict).lower()
|
|
||||||
@@ -1,410 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from camel.agents.chat_agent import ChatAgent
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
|
||||||
from camel.types import RoleType, TaskType
|
|
||||||
from camel.utils import get_task_list
|
|
||||||
|
|
||||||
# AgentOps decorator setting
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import track_agent
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
from camel.utils import track_agent
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that specifies a given task prompt by prompting the user to
    provide more details.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is not None:
            # Caller supplied a prompt; wrap it so formatting works uniformly.
            self.task_specify_prompt = TextPrompt(task_specify_prompt)
        else:
            # Fall back to the template registered for this task type.
            template = PromptTemplateGenerator().get_task_specify_prompt(
                task_type
            )
            self.task_specify_prompt = template.format(word_limit=word_limit)

        specifier_sys_msg = BaseMessage(
            role_name="Task Specifier",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You can make a task more specific.",
        )

        super().__init__(
            specifier_sys_msg,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.
        """
        self.reset()
        filled_prompt = self.task_specify_prompt.format(task=task_prompt)

        if meta_dict is not None:
            # Inject any extra placeholders the caller provided.
            filled_prompt = filled_prompt.format(**meta_dict)
        specifier_msg = BaseMessage.make_user_message(
            role_name="Task Specifier", content=filled_prompt
        )
        specifier_response = self.step(specifier_msg)

        if specifier_response.terminated:
            raise RuntimeError("Task specification failed.")
        if not specifier_response.msgs:
            raise RuntimeError("Got no specification message.")

        return TextPrompt(specifier_response.msgs[0].content)
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        planner_sys_msg = BaseMessage(
            role_name="Task Planner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task planner.",
        )

        super().__init__(
            planner_sys_msg,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.
        """
        # TODO: Maybe include roles information.
        self.reset()
        planning_request = self.task_planner_prompt.format(task=task_prompt)

        planning_msg = BaseMessage.make_user_message(
            role_name="Task Planner", content=planning_request
        )

        planning_response = self.step(planning_msg)

        if planning_response.terminated:
            raise RuntimeError("Task planning failed.")
        if not planning_response.msgs:
            raise RuntimeError("Got no task planning message.")

        return TextPrompt(planning_response.msgs[0].content)
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:3)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        creation_template = TextPrompt(
            """Create new a task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

#. First Task
#. Second Task
#. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        # Pre-fill everything except {task_list}, which varies per call.
        self.task_creation_prompt = creation_template.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        creator_sys_msg = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )

        super().__init__(
            creator_sys_msg,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """
        # An empty history is rendered as an empty string in the prompt.
        creation_request = self.task_creation_prompt.format(
            task_list=task_list if task_list else ""
        )

        creation_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=creation_request
        )
        creation_response = self.step(creation_msg)

        if creation_response.terminated:
            raise RuntimeError("Task creation failed.")
        if not creation_response.msgs:
            raise RuntimeError("Got no task creation message.")

        return get_task_list(creation_response.msgs[0].content)
|
|
||||||
|
|
||||||
|
|
||||||
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        prioritization_template = TextPrompt(
            """Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

#. First task
#. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        # Pre-fill the objective; {task_list} is supplied on every run().
        self.task_prioritization_prompt = prioritization_template.format(
            objective=objective
        )
        self.objective = objective

        prioritizer_sys_msg = BaseMessage(
            role_name="Task Prioritizer",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task prioritizer.",
        )

        super().__init__(
            prioritizer_sys_msg,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        prioritization_request = self.task_prioritization_prompt.format(
            task_list=task_list
        )

        prioritization_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer", content=prioritization_request
        )

        prioritization_response = self.step(prioritization_msg)

        if prioritization_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if not prioritization_response.msgs:
            raise RuntimeError("Got no task prioritization message.")

        return get_task_list(prioritization_response.msgs[0].content)
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .base import BaseToolAgent
|
|
||||||
from .hugging_face_tool_agent import HuggingFaceToolAgent
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseToolAgent',
|
|
||||||
'HuggingFaceToolAgent',
|
|
||||||
]
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from camel.agents import BaseAgent
|
|
||||||
|
|
||||||
|
|
||||||
class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        # Identifier and human-readable summary, both surfaced by __str__.
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state.

        The base implementation is a no-op; stateful subclasses override it.
        """

    def step(self) -> None:
        r"""Performs a single step of the agent.

        The base implementation is a no-op placeholder; concrete tool
        agents implement the actual action here.
        """

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
|
|
||||||
@@ -1,206 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from camel.agents.tool_agents.base import BaseToolAgent
|
|
||||||
|
|
||||||
|
|
||||||
# flake8: noqa :E501
|
|
||||||
class HuggingFaceToolAgent(BaseToolAgent):
|
|
||||||
r"""Tool agent for calling HuggingFace models. This agent is a wrapper
|
|
||||||
around agents from the `transformers` library. For more information
|
|
||||||
about the available models, please see the `transformers` documentation
|
|
||||||
at https://huggingface.co/docs/transformers/transformers_agents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name of the agent.
|
|
||||||
*args (Any): Additional positional arguments to pass to the underlying
|
|
||||||
Agent class.
|
|
||||||
remote (bool, optional): Flag indicating whether to run the agent
|
|
||||||
remotely. (default: :obj:`True`)
|
|
||||||
**kwargs (Any): Additional keyword arguments to pass to the underlying
|
|
||||||
Agent class.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
*args: Any,
|
|
||||||
remote: bool = True,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
# TODO: Support other tool agents
|
|
||||||
import transformers
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
if version.parse(transformers.__version__) < version.parse(
|
|
||||||
"4.31.0"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"The version of \"transformers\" package should >= 4.31.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers.tools import OpenAiAgent
|
|
||||||
from transformers.tools.agent_types import AgentImage
|
|
||||||
except (ImportError, ValueError):
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import transformers tool agents. "
|
|
||||||
"Please setup the environment with "
|
|
||||||
"pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
|
|
||||||
)
|
|
||||||
self.agent_image_type = AgentImage
|
|
||||||
self.agent = OpenAiAgent(*args, **kwargs)
|
|
||||||
description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
|
|
||||||
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
|
|
||||||
- Text question answering: given a long text and a question, answer the question in the text
|
|
||||||
- Unconditional image captioning: Caption the image!
|
|
||||||
- Image question answering: given an image, answer a question on this image
|
|
||||||
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
|
|
||||||
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
|
|
||||||
- Text to speech: convert text to speech
|
|
||||||
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
|
|
||||||
- Text summarization: summarize a long text in one or a few sentences
|
|
||||||
- Translation: translate the text into a given language
|
|
||||||
- Text downloading: to download a text from a web URL
|
|
||||||
- Text to image: generate an image according to a prompt, leveraging stable diffusion
|
|
||||||
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
|
|
||||||
- Text to video: generate a small video according to a prompt
|
|
||||||
|
|
||||||
Here are some python code examples of what you can do with this agent:
|
|
||||||
|
|
||||||
Single execution (step) mode, the single execution method is when using the step() method of the agent:
|
|
||||||
```
|
|
||||||
# Text to image
|
|
||||||
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
|
|
||||||
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
|
|
||||||
|
|
||||||
# Text to image -> Image transformation
|
|
||||||
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
|
|
||||||
sea_add_island_image.save("./sea_add_island_image.png")
|
|
||||||
|
|
||||||
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
|
|
||||||
# you can do so by specifying variables that you would like the agent to use. For example,
|
|
||||||
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
|
|
||||||
picture = {name}.step("Generate a picture of rivers and lakes.")
|
|
||||||
picture.save("./picture.png")
|
|
||||||
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
|
|
||||||
updated_picture.save("./updated_picture.png")
|
|
||||||
|
|
||||||
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
|
|
||||||
capybara_sea_image.save("./capybara_sea_image.png")
|
|
||||||
|
|
||||||
# Document question answering
|
|
||||||
answer = {name}.step(
|
|
||||||
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
|
|
||||||
document=document,
|
|
||||||
)
|
|
||||||
print(answer)
|
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
|
||||||
boat_image = {name}.step("Generate an image of a boat in the water")
|
|
||||||
boat_image.save("./boat_image.png")
|
|
||||||
|
|
||||||
# Unconditional image captioning
|
|
||||||
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
|
|
||||||
print(boat_image_caption)
|
|
||||||
|
|
||||||
# Text to image -> Unconditional image captioning -> Text to speech
|
|
||||||
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
|
|
||||||
|
|
||||||
# Text downloading
|
|
||||||
document = {name}.step("Download the text from http://hf.co")
|
|
||||||
print(document)
|
|
||||||
|
|
||||||
# Text summarization
|
|
||||||
summary = {name}.step("Summarize the following text: `document`", document=document)
|
|
||||||
print(summary)
|
|
||||||
|
|
||||||
# Text downloading -> Text summarization -> Text to speech
|
|
||||||
audio = {name}.step("Read out loud the summary of http://hf.co")
|
|
||||||
```
|
|
||||||
|
|
||||||
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
|
|
||||||
```
|
|
||||||
# Clean the chat history
|
|
||||||
{name}.reset()
|
|
||||||
|
|
||||||
# Text to image
|
|
||||||
capybara_image = {name}.chat("Show me an an image of a capybara")
|
|
||||||
capybara_image.save("./capybara_image.png")
|
|
||||||
|
|
||||||
# Image transformation
|
|
||||||
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
|
|
||||||
transformed_capybara_image.save("./transformed_capybara_image.png")
|
|
||||||
|
|
||||||
# Image segmentation
|
|
||||||
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
|
|
||||||
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
super(HuggingFaceToolAgent, self).__init__(name, description)
|
|
||||||
self.remote = remote
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
r"""Resets the chat history of the agent."""
|
|
||||||
self.agent.prepare_for_new_chat()
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self,
|
|
||||||
*args: Any,
|
|
||||||
remote: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
r"""Runs the agent in single execution mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args (Any): Positional arguments to pass to the agent.
|
|
||||||
remote (bool, optional): Flag indicating whether to run the agent
|
|
||||||
remotely. Overrides the default setting. (default: :obj:`None`)
|
|
||||||
**kwargs (Any): Keyword arguments to pass to the agent.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The response from the agent.
|
|
||||||
"""
|
|
||||||
if remote is None:
|
|
||||||
remote = self.remote
|
|
||||||
agent_output = self.agent.run(*args, remote=remote, **kwargs)
|
|
||||||
if isinstance(agent_output, self.agent_image_type):
|
|
||||||
agent_output = agent_output.to_raw()
|
|
||||||
return agent_output
|
|
||||||
|
|
||||||
def chat(
|
|
||||||
self,
|
|
||||||
*args: Any,
|
|
||||||
remote: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
r"""Runs the agent in a chat conversation mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args (Any): Positional arguments to pass to the agent.
|
|
||||||
remote (bool, optional): Flag indicating whether to run the agent
|
|
||||||
remotely. Overrides the default setting. (default: :obj:`None`)
|
|
||||||
**kwargs (Any): Keyword arguments to pass to the agent.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The response from the agent.
|
|
||||||
"""
|
|
||||||
if remote is None:
|
|
||||||
remote = self.remote
|
|
||||||
agent_output = self.agent.chat(*args, remote=remote, **kwargs)
|
|
||||||
if isinstance(agent_output, self.agent_image_type):
|
|
||||||
agent_output = agent_output.to_raw()
|
|
||||||
return agent_output
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .base import BaseBenchmark
|
|
||||||
|
|
||||||
__all__ = ["BaseBenchmark"]
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
|
||||||
|
|
||||||
from camel.agents import ChatAgent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBenchmark(ABC):
|
|
||||||
r"""Base class for benchmarks.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
name (str): Name of the benchmark.
|
|
||||||
data_dir (str): Path to the data directory.
|
|
||||||
save_to (str): Path to save the results.
|
|
||||||
processes (int): Number of processes to use for parallel
|
|
||||||
processing. :(default: :obj:`1`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, name: str, data_dir: str, save_to: str, processes: int = 1
|
|
||||||
):
|
|
||||||
r"""Initialize the benchmark.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): Name of the benchmark.
|
|
||||||
data_dir (str): Path to the data directory.
|
|
||||||
save_to (str): Path to save the results.
|
|
||||||
processes (int): Number of processes to use for parallel
|
|
||||||
processing. :(default: :obj:`1`)
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.name = name
|
|
||||||
self.data_dir = Path(data_dir)
|
|
||||||
self.processes = processes
|
|
||||||
self.save_to = save_to
|
|
||||||
if not self.data_dir.exists():
|
|
||||||
logger.info(
|
|
||||||
f"Data directory {data_dir} does not exist. Creating it."
|
|
||||||
)
|
|
||||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
if not self.data_dir.is_dir():
|
|
||||||
raise NotADirectoryError(
|
|
||||||
f"Data directory {data_dir} is not a directory"
|
|
||||||
)
|
|
||||||
self._data: Dict[str, List[Dict[str, Any]]] = dict()
|
|
||||||
self._results: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def download(self) -> "BaseBenchmark":
|
|
||||||
r"""Download the benchmark data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseBenchmark: The benchmark instance.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load(self, force_download: bool = False) -> "BaseBenchmark":
|
|
||||||
r"""Load the benchmark data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force_download (bool): Whether to force download the data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseBenchmark: The benchmark instance.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def train(self) -> List[Dict[str, Any]]:
|
|
||||||
r"""Get the training data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: The training data.
|
|
||||||
"""
|
|
||||||
if not self._data:
|
|
||||||
logger.info("Data not loaded. Loading data.")
|
|
||||||
self.load()
|
|
||||||
return self._data["train"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def valid(self) -> List[Dict[str, Any]]:
|
|
||||||
r"""Get the validation data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: The validation data.
|
|
||||||
"""
|
|
||||||
if not self._data:
|
|
||||||
logger.info("Data not loaded. Loading data.")
|
|
||||||
self.load()
|
|
||||||
return self._data["valid"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def test(self) -> List[Dict[str, Any]]:
|
|
||||||
r"""Get the test data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: The test data.
|
|
||||||
"""
|
|
||||||
if not self._data:
|
|
||||||
logger.info("Data not loaded. Loading data.")
|
|
||||||
self.load()
|
|
||||||
return self._data["test"]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
agent: ChatAgent,
|
|
||||||
on: Literal["train", "valid", "test"],
|
|
||||||
randomize: bool = False,
|
|
||||||
subset: Optional[int] = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> "BaseBenchmark":
|
|
||||||
r"""Run the benchmark.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent (ChatAgent): The chat agent.
|
|
||||||
on (str): The data split to run the benchmark on.
|
|
||||||
randomize (bool): Whether to randomize the data.
|
|
||||||
subset (int): The subset of the data to run the benchmark on.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseBenchmark: The benchmark instance.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def results(self) -> List[Dict[str, Any]]:
|
|
||||||
r"""Get the results.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: The results.
|
|
||||||
"""
|
|
||||||
return self._results
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .discord_app import DiscordApp
|
|
||||||
from .slack.models import (
|
|
||||||
SlackAppMentionEventBody,
|
|
||||||
SlackAppMentionEventProfile,
|
|
||||||
SlackAuthProfile,
|
|
||||||
SlackEventBody,
|
|
||||||
SlackEventProfile,
|
|
||||||
)
|
|
||||||
from .slack.slack_app import SlackApp
|
|
||||||
from .telegram_bot import TelegramBot
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'DiscordApp',
|
|
||||||
'SlackApp',
|
|
||||||
'SlackAppMentionEventBody',
|
|
||||||
'SlackAppMentionEventProfile',
|
|
||||||
'SlackAuthProfile',
|
|
||||||
'SlackEventBody',
|
|
||||||
'SlackEventProfile',
|
|
||||||
'TelegramBot',
|
|
||||||
]
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
|
||||||
|
|
||||||
from camel.utils import dependencies_required
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from discord import Message
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordApp:
|
|
||||||
r"""A class representing a Discord app that uses the `discord.py` library
|
|
||||||
to interact with Discord servers.
|
|
||||||
|
|
||||||
This bot can respond to messages in specific channels and only reacts to
|
|
||||||
messages that mention the bot.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
|
|
||||||
provided, the bot will only respond to messages in these channels.
|
|
||||||
token (Optional[str]): The Discord bot token used for authentication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@dependencies_required('discord')
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channel_ids: Optional[List[int]] = None,
|
|
||||||
token: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
r"""Initialize the DiscordApp instance by setting up the Discord client
|
|
||||||
and event handlers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs.
|
|
||||||
The bot will only respond to messages in these channels if
|
|
||||||
provided.
|
|
||||||
token (Optional[str]): The Discord bot token for authentication.
|
|
||||||
If not provided, the token will be retrieved from the
|
|
||||||
environment variable `DISCORD_TOKEN`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the `DISCORD_TOKEN` is not found in environment
|
|
||||||
variables.
|
|
||||||
"""
|
|
||||||
self.token = token or os.getenv('DISCORD_TOKEN')
|
|
||||||
self.channel_ids = channel_ids
|
|
||||||
|
|
||||||
if not self.token:
|
|
||||||
raise ValueError(
|
|
||||||
"`DISCORD_TOKEN` not found in environment variables. Get it"
|
|
||||||
" here: `https://discord.com/developers/applications`."
|
|
||||||
)
|
|
||||||
|
|
||||||
import discord
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
|
||||||
intents.message_content = True
|
|
||||||
self._client = discord.Client(intents=intents)
|
|
||||||
|
|
||||||
# Register event handlers
|
|
||||||
self._client.event(self.on_ready)
|
|
||||||
self._client.event(self.on_message)
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
r"""Asynchronously start the Discord bot using its token.
|
|
||||||
|
|
||||||
This method starts the bot and logs into Discord asynchronously using
|
|
||||||
the provided token. It should be awaited when used in an async
|
|
||||||
environment.
|
|
||||||
"""
|
|
||||||
await self._client.start(self.token)
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
r"""Start the Discord bot using its token.
|
|
||||||
|
|
||||||
This method starts the bot and logs into Discord synchronously using
|
|
||||||
the provided token. It blocks execution and keeps the bot running.
|
|
||||||
"""
|
|
||||||
self._client.run(self.token) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
async def on_ready(self) -> None:
|
|
||||||
r"""Event handler that is called when the bot has successfully
|
|
||||||
connected to the Discord server.
|
|
||||||
|
|
||||||
When the bot is ready and logged into Discord, it prints a message
|
|
||||||
displaying the bot's username.
|
|
||||||
"""
|
|
||||||
logger.info(f'We have logged in as {self._client.user}')
|
|
||||||
|
|
||||||
async def on_message(self, message: 'Message') -> None:
|
|
||||||
r"""Event handler for processing incoming messages.
|
|
||||||
|
|
||||||
This method is called whenever a new message is received by the bot. It
|
|
||||||
will ignore messages sent by the bot itself, only respond to messages
|
|
||||||
in allowed channels (if specified), and only to messages that mention
|
|
||||||
the bot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message (discord.Message): The message object received from
|
|
||||||
Discord.
|
|
||||||
"""
|
|
||||||
# If the message author is the bot itself,
|
|
||||||
# do not respond to this message
|
|
||||||
if message.author == self._client.user:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If allowed channel IDs are provided,
|
|
||||||
# only respond to messages in those channels
|
|
||||||
if self.channel_ids and message.channel.id not in self.channel_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only respond to messages that mention the bot
|
|
||||||
if not self._client.user or not self._client.user.mentioned_in(
|
|
||||||
message
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Received message: {message.content}")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self):
|
|
||||||
return self._client
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .models import (
|
|
||||||
SlackAppMentionEventBody,
|
|
||||||
SlackAppMentionEventProfile,
|
|
||||||
SlackAuthProfile,
|
|
||||||
SlackEventBody,
|
|
||||||
SlackEventProfile,
|
|
||||||
)
|
|
||||||
from .slack_app import SlackApp
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'SlackApp',
|
|
||||||
'SlackAppMentionEventBody',
|
|
||||||
'SlackAppMentionEventProfile',
|
|
||||||
'SlackAuthProfile',
|
|
||||||
'SlackEventBody',
|
|
||||||
'SlackEventProfile',
|
|
||||||
]
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class SlackAuthProfile(BaseModel):
|
|
||||||
r"""Represents the authorization profile within a Slack event.
|
|
||||||
|
|
||||||
Events will contain a single, compact authorizations field that shows one
|
|
||||||
installation of your app that the event is visible to.
|
|
||||||
In other words, lists of authorizations will be truncated to one element.
|
|
||||||
|
|
||||||
If there's more than one installing party that your app is keeping track
|
|
||||||
of, it's best not to rely on the single party listed in authorizations to
|
|
||||||
be any particular one.
|
|
||||||
|
|
||||||
To get a full list of who can see events, call the apps.event.
|
|
||||||
authorizations.list method after obtaining an app-level token. Read more on
|
|
||||||
the changes here; they have taken effect for existing apps as of
|
|
||||||
February 24, 2021.
|
|
||||||
|
|
||||||
References:
|
|
||||||
|
|
||||||
- https://api.slack.com/apis/events-api#authorizations
|
|
||||||
- https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
|
|
||||||
"""
|
|
||||||
|
|
||||||
enterprise_id: Optional[str] = None
|
|
||||||
"""The ID of the enterprise associated with the authorization."""
|
|
||||||
|
|
||||||
team_id: str
|
|
||||||
"""The ID of the team associated with the authorization."""
|
|
||||||
|
|
||||||
user_id: str
|
|
||||||
"""The ID of the user associated with the authorization."""
|
|
||||||
|
|
||||||
is_bot: bool
|
|
||||||
"""Whether the authorized user is a bot."""
|
|
||||||
|
|
||||||
is_enterprise_install: bool
|
|
||||||
"""Whether the authorization is for an enterprise installation."""
|
|
||||||
|
|
||||||
|
|
||||||
class SlackEventProfile(BaseModel):
|
|
||||||
r"""Represents the detailed profile of a Slack event, including user,
|
|
||||||
message, and context data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
user: str
|
|
||||||
"""The ID of the user associated with the event."""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
"""The type of the event (e.g., 'message')."""
|
|
||||||
|
|
||||||
ts: str
|
|
||||||
"""A timestamp representing when the event was triggered."""
|
|
||||||
|
|
||||||
thread_ts: Optional[str] = None
|
|
||||||
"""The timestamp of the parent message in a thread."""
|
|
||||||
|
|
||||||
client_msg_id: str
|
|
||||||
"""A unique ID generated by the client for the message (if available)."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""The message content text."""
|
|
||||||
|
|
||||||
team: str
|
|
||||||
"""The ID of the team that the event is associated with."""
|
|
||||||
|
|
||||||
blocks: list
|
|
||||||
"""The list of message blocks, providing structured information."""
|
|
||||||
|
|
||||||
channel: str
|
|
||||||
"""The ID of the Slack channel where the event happened."""
|
|
||||||
|
|
||||||
event_ts: str
|
|
||||||
"""The event-specific timestamp when it occurred."""
|
|
||||||
|
|
||||||
channel_type: Optional[str]
|
|
||||||
"""The type of Slack channel (e.g., 'channel', 'im')."""
|
|
||||||
|
|
||||||
|
|
||||||
class SlackEventBody(BaseModel):
|
|
||||||
r"""Represents the entire body of a Slack event, including the event
|
|
||||||
profile, authorization, and context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
token: str
|
|
||||||
"""The token to verify the source of the event."""
|
|
||||||
|
|
||||||
team_id: str
|
|
||||||
"""The ID of the team where the event is happening."""
|
|
||||||
|
|
||||||
context_team_id: Optional[str]
|
|
||||||
"""The team ID for the shared channel context, if applicable."""
|
|
||||||
|
|
||||||
context_enterprise_id: Optional[str] = None
|
|
||||||
"""The enterprise ID for the shared channel context, if applicable."""
|
|
||||||
|
|
||||||
api_app_id: str
|
|
||||||
"""The unique identifier for the Slack app that received the event."""
|
|
||||||
|
|
||||||
event: SlackEventProfile
|
|
||||||
"""A detailed profile of the event"""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
"""The overall type of event received (e.g., 'event_callback')."""
|
|
||||||
|
|
||||||
event_id: str
|
|
||||||
"""A unique identifier assigned to this event by Slack."""
|
|
||||||
|
|
||||||
event_time: int
|
|
||||||
"""The timestamp (in seconds) representing when the event was triggered."""
|
|
||||||
|
|
||||||
authorizations: Optional[list[SlackAuthProfile]] = None
|
|
||||||
"""An optional list of authorizations that describe which installation can
|
|
||||||
see the event."""
|
|
||||||
|
|
||||||
is_ext_shared_channel: bool
|
|
||||||
"""Indicates if the event is part of a shared channel between different
|
|
||||||
organizations."""
|
|
||||||
|
|
||||||
event_context: str
|
|
||||||
"""A unique string representing the context of the event."""
|
|
||||||
|
|
||||||
|
|
||||||
class SlackAppMentionEventProfile(SlackEventProfile):
|
|
||||||
r"""Represents the detailed profile of a Slack event where the app was
|
|
||||||
mentioned in a message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
channel_type: Optional[str] = None
|
|
||||||
"""The type of Slack channel. it's None for app mentions."""
|
|
||||||
|
|
||||||
|
|
||||||
class SlackAppMentionEventBody(SlackEventBody):
|
|
||||||
r"""Represents the entire body of a Slack event where the app was mentioned
|
|
||||||
in a message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
context_team_id: Optional[str] = None
|
|
||||||
"""A detailed profile of the event. it's None for app mentions."""
|
|
||||||
|
|
||||||
event: SlackAppMentionEventProfile
|
|
||||||
"""A detailed profile of the event"""
|
|
||||||
@@ -1,255 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
||||||
|
|
||||||
from slack_sdk.oauth.installation_store.async_installation_store import (
|
|
||||||
AsyncInstallationStore,
|
|
||||||
)
|
|
||||||
from starlette import requests, responses
|
|
||||||
|
|
||||||
from camel.bots.slack.models import (
|
|
||||||
SlackAppMentionEventBody,
|
|
||||||
SlackAppMentionEventProfile,
|
|
||||||
SlackEventBody,
|
|
||||||
SlackEventProfile,
|
|
||||||
)
|
|
||||||
from camel.utils import dependencies_required
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from slack_bolt.context.async_context import AsyncBoltContext
|
|
||||||
from slack_bolt.context.say.async_say import AsyncSay
|
|
||||||
from slack_sdk.web.async_client import AsyncWebClient
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults to
            "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The installation
            store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt app
        and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth redirects
                (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An optional
                installation store for OAuth installations.

        Raises:
            ValueError: If the token, scopes, or signing secret cannot be
                resolved from arguments or environment variables.
        """
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        # Explicit arguments win; environment variables are the fallback.
        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )

        # Setup OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )

        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as `app_mention`
        and `message`.

        This method registers the `app_mention` and `on_message` event handlers
        with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events (default
                is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default is
                None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)

        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        await context.ack()

        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)

        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")

        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False otherwise.
        """
        message = body.event.text
        # Fix: textless events (e.g. file shares) carry `None` here; the
        # original `mention in message` would raise TypeError in that case.
        if not message:
            return False
        bot_user_id = context.bot_user_id
        mention = f"<@{bot_user_id}>"
        return mention in message
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
|
|
||||||
from camel.agents import ChatAgent
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.utils import dependencies_required
|
|
||||||
|
|
||||||
# Conditionally import telebot types only for type checking
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from telebot.types import ( # type: ignore[import-untyped]
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        self.chat_agent = chat_agent

        # An explicitly supplied token wins; otherwise fall back to the
        # environment and fail loudly when neither source provides one.
        self.token = telegram_token or os.getenv('TELEGRAM_TOKEN')
        if not self.token:
            raise ValueError(
                "`TELEGRAM_TOKEN` not found in environment variables. "
                "Get it from t.me/BotFather."
            )

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Route every incoming message to `on_message`.
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        # Each exchange starts from a clean agent state.
        self.chat_agent.reset()

        text = message.text
        if not text:
            # Nothing to respond to (stickers, photos, etc.).
            return

        prompt = BaseMessage.make_user_message(
            role_name="User", content=text
        )
        reply = self.chat_agent.step(prompt)

        self.bot.reply_to(message, reply.msg.content)
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
|
|
||||||
from .base_config import BaseConfig
|
|
||||||
from .cohere_config import COHERE_API_PARAMS, CohereConfig
|
|
||||||
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
|
|
||||||
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
|
|
||||||
from .groq_config import GROQ_API_PARAMS, GroqConfig
|
|
||||||
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
|
|
||||||
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
|
|
||||||
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
|
|
||||||
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
|
|
||||||
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
|
|
||||||
from .qwen_config import QWEN_API_PARAMS, QwenConfig
|
|
||||||
from .reka_config import REKA_API_PARAMS, RekaConfig
|
|
||||||
from .samba_config import (
|
|
||||||
SAMBA_CLOUD_API_PARAMS,
|
|
||||||
SAMBA_VERSE_API_PARAMS,
|
|
||||||
SambaCloudAPIConfig,
|
|
||||||
SambaVerseAPIConfig,
|
|
||||||
)
|
|
||||||
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
|
|
||||||
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
|
|
||||||
from .yi_config import YI_API_PARAMS, YiConfig
|
|
||||||
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
|
|
||||||
|
|
||||||
# Public API of `camel.configs`: each model backend exposes a config class
# plus the set of parameter names its API accepts.
__all__ = [
    'BaseConfig',
    'ChatGPTConfig',
    'OPENAI_API_PARAMS',
    'AnthropicConfig',
    'ANTHROPIC_API_PARAMS',
    'GROQ_API_PARAMS',
    'GroqConfig',
    'LiteLLMConfig',
    'LITELLM_API_PARAMS',
    'NvidiaConfig',
    'NVIDIA_API_PARAMS',
    'OllamaConfig',
    'OLLAMA_API_PARAMS',
    'ZhipuAIConfig',
    'ZHIPUAI_API_PARAMS',
    'GeminiConfig',
    'Gemini_API_PARAMS',
    'VLLMConfig',
    'VLLM_API_PARAMS',
    'MistralConfig',
    'MISTRAL_API_PARAMS',
    'RekaConfig',
    'REKA_API_PARAMS',
    'SambaVerseAPIConfig',
    'SAMBA_VERSE_API_PARAMS',
    'SambaCloudAPIConfig',
    'SAMBA_CLOUD_API_PARAMS',
    'TogetherAIConfig',
    'TOGETHERAI_API_PARAMS',
    'CohereConfig',
    'COHERE_API_PARAMS',
    'YiConfig',
    'YI_API_PARAMS',
    'QwenConfig',
    'QWEN_API_PARAMS',
    'DeepSeekConfig',
    'DEEPSEEK_API_PARAMS',
]
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post
    Args:
        max_tokens (int, optional): The maximum number of tokens to
            generate before stopping. Note that Anthropic models may stop
            before reaching this maximum. This parameter only specifies the
            absolute maximum number of tokens to generate.
            (default: :obj:`256`)
        stop_sequences (List[str], optional): Sequences that will cause the
            model to stop generating completion text. Anthropic models stop
            on "\n\nHuman:", and may include additional built-in stop sequences
            in the future. By providing the stop_sequences parameter, you may
            include additional strings that will cause the model to stop
            generating. (default: :obj:`NOT_GIVEN`)
        temperature (float, optional): Amount of randomness injected into the
            response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
            for analytical / multiple choice, and closer to 1 for creative
            and generative tasks.
            (default: :obj:`1`)
        top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
            compute the cumulative distribution over all the options for each
            subsequent token in decreasing probability order and cut it off
            once it reaches a particular probability specified by `top_p`.
            You should either alter `temperature` or `top_p`,
            but not both.
            (default: :obj:`NOT_GIVEN`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses.
            (default: :obj:`NOT_GIVEN`)
        metadata: An object describing metadata about the request.
            (default: :obj:`NOT_GIVEN`)
        stream (bool, optional): Whether to incrementally stream the response
            using server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 256
    stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
    temperature: float = 1
    top_p: Union[float, NotGiven] = NOT_GIVEN
    top_k: Union[int, NotGiven] = NOT_GIVEN
    metadata: NotGiven = NOT_GIVEN
    stream: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Set of parameter names accepted by the Anthropic API, derived from the
# config's declared fields. `set(...)` replaces the redundant comprehension
# over `.keys()`.
ANTHROPIC_API_PARAMS = set(AnthropicConfig.model_fields.keys())
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(ABC, BaseModel):
    r"""Base configuration class for all models.

    This class provides a common interface for all models, ensuring that all
    models have a consistent set of attributes and methods.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    # A list of tools the model may call. Currently, only functions are
    # supported as a tool. Use this to provide a list of functions the model
    # may generate JSON inputs for. A max of 128 functions are supported.
    tools: Optional[List[Any]] = None

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate the type of tools in the configuration.

        This method ensures that the tools provided in the configuration are
        instances of `FunctionTool`. If any tool is not an instance of
        `FunctionTool`, it raises a ValueError.
        """
        if tools is None:
            return tools

        # Imported lazily to avoid a circular dependency at module import.
        from camel.toolkits import FunctionTool

        for entry in tools:
            if not isinstance(entry, FunctionTool):
                raise ValueError(
                    f"The tool {entry} should "
                    "be an instance of `FunctionTool`."
                )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        dumped = self.model_dump()

        schemas = None
        if self.tools:
            from camel.toolkits import FunctionTool

            schemas = []
            for entry in self.tools:
                if not isinstance(entry, FunctionTool):
                    raise ValueError(
                        f"The tool {entry} should "
                        "be an instance of `FunctionTool`."
                    )
                schemas.append(entry.get_openai_tool_schema())
        # Overwrite unconditionally: when `tools` is falsy the dumped value
        # is normalized to `None`.
        dumped["tools"] = schemas
        return dumped
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List(str), optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation. (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75
|
|
||||||
|
|
||||||
|
|
||||||
# Set of parameter names accepted by the Cohere API. Access `model_fields`
# on the class rather than on a throwaway instance: instance access is
# deprecated in pydantic >= 2.11 and the instantiation was needless — this
# also matches how the sibling config modules derive their parameter sets.
COHERE_API_PARAMS = set(CohereConfig.model_fields.keys())
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Optional, Sequence, Type, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
DeepSeek API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): Controls the diversity and focus of the
|
|
||||||
generated results. Higher values make the output more diverse,
|
|
||||||
while lower values make it more focused. (default: :obj:`1.0`)
|
|
||||||
response_format (object, optional): Specifies the format of the
|
|
||||||
returned content. The available values are `{"type": "text"}` or
|
|
||||||
`{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
|
|
||||||
will output a standard JSON string.
|
|
||||||
(default: :obj:`{"type": "text"}`)
|
|
||||||
stream (bool, optional): If set, partial message deltas will be sent.
|
|
||||||
Tokens will be sent as data-only server-sent events (SSE) as
|
|
||||||
they become available, with the stream terminated by a
|
|
||||||
data: [DONE] message. (default: :obj:`False`)
|
|
||||||
stop (Union[str, list[str]], optional): Up to 16 sequences where
|
|
||||||
the API will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens that can
|
|
||||||
be generated in the chat completion. The total length of input
|
|
||||||
tokens and generated tokens is limited by the model's context
|
|
||||||
length. (default: :obj:`None`)
|
|
||||||
presence_penalty (float, optional): Number between -2.0 and 2.0.
|
|
||||||
Positive values penalize new tokens based on whether they
|
|
||||||
appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between -2.0 and 2.0.
|
|
||||||
Positive values penalize new tokens based on their existing
|
|
||||||
frequency in the text so far, decreasing the model's likelihood
|
|
||||||
to repeat the same line verbatim. (default: :obj:`0`)
|
|
||||||
tools (list[FunctionTool], optional): A list of tools the model may
|
|
||||||
call. Currently, only functions are supported as a tool. Use
|
|
||||||
this to provide a list of functions the model may generate JSON
|
|
||||||
inputs for. A max of 128 functions are supported.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
tool_choice (Union[dict[str, str], str], optional): Controls which
|
|
||||||
(if any) tool is called by the model. "none" means the model
|
|
||||||
will not call any tool and instead generates a message. "auto"
|
|
||||||
means the model can pick between generating a message or calling
|
|
||||||
one or more tools. "required" means the model must call one or
|
|
||||||
more tools. Specifying a particular tool via
|
|
||||||
{"type": "function", "function": {"name": "my_function"}} forces
|
|
||||||
the model to call that tool. "none" is the default when no tools
|
|
||||||
are present. "auto" is the default if tools are present.
|
|
||||||
(default: :obj:`"auto"`)
|
|
||||||
logprobs (bool, optional): Whether to return log probabilities of
|
|
||||||
the output tokens or not. If true, returns the log probabilities
|
|
||||||
of each output token returned in the content of message.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
top_logprobs (int, optional): An integer between 0 and 20 specifying
|
|
||||||
the number of most likely tokens to return at each token
|
|
||||||
position, each with an associated log probability. logprobs
|
|
||||||
must be set to true if this parameter is used.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
include_usage (bool, optional): When streaming, specifies whether to
|
|
||||||
include usage information in `stream_options`. (default:
|
|
||||||
:obj:`True`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2 # deepseek default: 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
||||||
logprobs: bool = False
|
|
||||||
top_logprobs: Optional[int] = None
|
|
||||||
|
|
||||||
def __init__(self, include_usage: bool = True, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# Only set stream_options when stream is True
|
|
||||||
# Otherwise, it will raise error when calling the API
|
|
||||||
if self.stream:
|
|
||||||
self.stream_options = {"include_usage": include_usage}
|
|
||||||
|
|
||||||
    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.
        When tools are configured, each one is validated to be a
        ``FunctionTool`` before the ``tools`` entry is cleared.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any configured tool is not a ``FunctionTool``.
        """
        config_dict = self.model_dump()
        if self.tools:
            # Imported lazily to avoid a circular import at module load time.
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
            # NOTE(review): `tools_schema` is built but never used — `tools`
            # is dropped from the dict via NOT_GIVEN (presumably handled by
            # the model backend instead). Confirm whether this should be
            # `config_dict["tools"] = tools_schema`.
            config_dict["tools"] = NOT_GIVEN
        return config_dict
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the DeepSeek API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
DEEPSEEK_API_PARAMS = set(DeepSeekConfig.model_fields.keys())
@@ -1,114 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Optional, Sequence, Type, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to {"type": "json_object"}
            enables JSON mode, which guarantees the message the model
            generates is valid JSON. Important: when using JSON mode, you
            must also instruct the model to produce JSON yourself via a
            system or user message. Without this, the model may generate an
            unending stream of whitespace until the generation reaches the
            token limit. Also note that the message content may be partially
            cut off if finish_reason="length".
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.
        When tools are configured, each one is validated to be a
        ``FunctionTool`` before the ``tools`` entry is cleared.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any configured tool is not a ``FunctionTool``.
        """
        config_dict = self.model_dump()
        if self.tools:
            # Imported lazily to avoid a circular import at module load time.
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
            # NOTE(review): `tools_schema` is built but never used — `tools`
            # is dropped from the dict via NOT_GIVEN. Confirm whether this
            # should be `config_dict["tools"] = tools_schema`.
            config_dict["tools"] = NOT_GIVEN
        return config_dict
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the Gemini API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
# NOTE(review): mixed-case name differs from the UPPER_SNAKE_CASE siblings
# (e.g. DEEPSEEK_API_PARAMS); kept as-is since it is a public module name.
Gemini_API_PARAMS = set(GeminiConfig.model_fields.keys())
@@ -1,104 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class GroqConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Reference: https://console.groq.com/docs/openai

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to {"type": "json_object"}
            enables JSON mode, which guarantees the message the model
            generates is valid JSON. Important: when using JSON mode, you
            must also instruct the model to produce JSON yourself via a
            system or user message. Without this, the model may generate an
            unending stream of whitespace until the generation reaches the
            token limit. Also note that the message content may be partially
            cut off if finish_reason="length".
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    user: str = ""
    # Defaults to "auto" here, unlike sibling configs which default to None;
    # presumably deliberate for Groq — confirm against the backend.
    tool_choice: Optional[Union[dict[str, str], str]] = "auto"
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the Groq API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
GROQ_API_PARAMS = set(GroqConfig.model_fields.keys())
@@ -1,97 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    LiteLLM API.

    All parameters default to :obj:`None`, meaning they are omitted from the
    request and LiteLLM's own defaults apply.

    Args:
        timeout (Optional[Union[float, str]], optional): Request timeout.
            (default: None)
        temperature (Optional[float], optional): Temperature parameter for
            controlling randomness. (default: None)
        top_p (Optional[float], optional): Top-p parameter for nucleus
            sampling. (default: None)
        n (Optional[int], optional): Number of completions to generate.
            (default: None)
        stream (Optional[bool], optional): Whether to return a streaming
            response. (default: None)
        stream_options (Optional[dict], optional): Options for the streaming
            response. (default: None)
        stop (Optional[Union[str, List[str]]], optional): Sequences where the
            API will stop generating further tokens. (default: None)
        max_tokens (Optional[int], optional): Maximum number of tokens to
            generate. (default: None)
        presence_penalty (Optional[float], optional): Penalize new tokens
            based on their existence in the text so far. (default: None)
        frequency_penalty (Optional[float], optional): Penalize new tokens
            based on their frequency in the text so far. (default: None)
        logit_bias (Optional[dict], optional): Modify the probability of
            specific tokens appearing in the completion. (default: None)
        user (Optional[str], optional): A unique identifier representing the
            end-user. (default: None)
        response_format (Optional[dict], optional): Response format
            parameters. (default: None)
        seed (Optional[int], optional): Random seed. (default: None)
        tools (Optional[List], optional): List of tools. (default: None)
        tool_choice (Optional[Union[str, dict]], optional): Tool choice
            parameters. (default: None)
        logprobs (Optional[bool], optional): Whether to return log
            probabilities of the output tokens. (default: None)
        top_logprobs (Optional[int], optional): Number of most likely tokens
            to return at each token position. (default: None)
        deployment_id (Optional[str], optional): Deployment ID. (default: None)
        extra_headers (Optional[dict], optional): Additional headers for the
            request. (default: None)
        api_version (Optional[str], optional): API version. (default: None)
        mock_response (Optional[str], optional): Mock completion response for
            testing or debugging. (default: None)
        custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
            provider. (default: None)
        max_retries (Optional[int], optional): Maximum number of retries.
            (default: None)
    """

    timeout: Optional[Union[float, str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[dict] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    # NOTE: `tools` is documented above but not declared here — presumably
    # inherited from BaseConfig; confirm.
    tool_choice: Optional[Union[str, dict]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    deployment_id: Optional[str] = None
    extra_headers: Optional[dict] = None
    api_version: Optional[str] = None
    mock_response: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    max_retries: Optional[int] = None
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the LiteLLM API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
LITELLM_API_PARAMS = set(LiteLLMConfig.model_fields.keys())
@@ -1,79 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import field_validator
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MistralConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Mistral API.

    reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195

    #TODO: Support stream mode

    Args:
        temperature (Optional[float], optional): temperature the temperature
            to use for sampling, e.g. 0.5.
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. Defaults to None.
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. Defaults to None.
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
        random_seed (Optional[int], optional): the random seed to use for
            sampling, e.g. 42. Defaults to None.
        safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
            Defaults to False.
        response_format (Union[Dict[str, str], ResponseFormat): format of the
            response.
        tool_choice (str, optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"any"` means the
            model must call one or more tools. :obj:`"auto"` is the default
            value.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    random_seed: Optional[int] = None
    safe_prompt: bool = False
    # Annotated as Any because mistralai's ResponseFormat is imported lazily
    # inside the validator below.
    response_format: Optional[Union[Dict[str, str], Any]] = None
    tool_choice: Optional[str] = "auto"

    @field_validator("response_format", mode="before")
    @classmethod
    def fields_type_checking(cls, response_format):
        r"""Validate that ``response_format`` is either a dict or a
        ``mistralai.models.ResponseFormat`` instance.

        Raises:
            ValueError: If ``response_format`` is neither a dict nor a
                ``ResponseFormat``.
        """
        if response_format and not isinstance(response_format, dict):
            # Imported lazily so the mistralai package is only required when
            # a non-dict response_format is actually supplied.
            from mistralai.models import ResponseFormat

            if not isinstance(response_format, ResponseFormat):
                # NOTE(review): message says "tool" but refers to
                # response_format — likely copy-pasted from tool validation.
                raise ValueError(
                    f"The tool {response_format} should be an instance "
                    "of `mistralai.models.ResponseFormat`."
                )
        return response_format
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the Mistral API config.
# Fix: access `model_fields` on the class instead of a throwaway instance —
# instance access is deprecated in Pydantic v2.11, constructs an object for
# nothing, and every sibling *_API_PARAMS already uses the class. Also
# replaces the redundant set comprehension with `set(...)`.
MISTRAL_API_PARAMS = set(MistralConfig.model_fields.keys())
@@ -1,70 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class NvidiaConfig(BaseConfig):
    r"""Configuration class for NVIDIA API models.

    This class defines the configuration parameters for NVIDIA's language
    models, including temperature, sampling parameters, and response format
    settings.

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`False`)
        temperature (float, optional): Controls randomness in the response.
            Higher values make output more random, lower values make it more
            deterministic. Range: [0.0, 2.0]. (default: :obj:`0.7`)
        top_p (float, optional): Controls diversity via nucleus sampling.
            Range: [0.0, 1.0]. (default: :obj:`0.95`)
        presence_penalty (float, optional): Penalizes new tokens based on
            whether they appear in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        frequency_penalty (float, optional): Penalizes new tokens based on
            their frequency in the text so far. Range: [-2.0, 2.0].
            (default: :obj:`0.0`)
        max_tokens (Union[int, NotGiven], optional): Maximum number of tokens
            to generate. If not provided, model will use its default maximum.
            (default: :obj:`NOT_GIVEN`)
        seed (Optional[int], optional): Random seed for deterministic sampling.
            (default: :obj:`None`)
        tools (Optional[List[Dict]], optional): List of tools available to the
            model. This includes tools such as a text editor, a calculator, or
            a search engine. (default: :obj:`None`)
        tool_choice (Optional[str], optional): Tool choice configuration.
            (default: :obj:`None`)
        stop (Optional[List[str]], optional): List of stop sequences.
            (default: :obj:`None`)
    """

    # `Field(default=...)` is used throughout for uniformity; plain defaults
    # would be equivalent here.
    stream: bool = Field(default=False)
    temperature: float = Field(default=0.7)
    top_p: float = Field(default=0.95)
    presence_penalty: float = Field(default=0.0)
    frequency_penalty: float = Field(default=0.0)
    max_tokens: Union[int, NotGiven] = Field(default=NOT_GIVEN)
    seed: Optional[int] = Field(default=None)
    # NOTE: `tools` is documented above but not declared here — presumably
    # inherited from BaseConfig; confirm.
    tool_choice: Optional[str] = Field(default=None)
    stop: Optional[List[str]] = Field(default=None)
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the NVIDIA API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
NVIDIA_API_PARAMS = set(NvidiaConfig.model_fields.keys())
@@ -1,82 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility

    Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to {"type": "json_object"}
            enables JSON mode, which guarantees the message the model
            generates is valid JSON. Important: when using JSON mode, you
            must also instruct the model to produce JSON yourself via a
            system or user message. Without this, the model may generate an
            unending stream of whitespace until the generation reaches the
            token limit. Also note that the message content may be partially
            cut off if finish_reason="length".
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
    """

    temperature: float = 0.2
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
|
|
||||||
|
|
||||||
# Names of all parameters accepted by the Ollama API config.
# `set(...)` over the keys view replaces the redundant set comprehension.
OLLAMA_API_PARAMS = set(OllamaConfig.model_fields.keys())
@@ -1,139 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Optional, Sequence, Type, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
OpenAI API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`1.0`)
|
|
||||||
n (int, optional): How many chat completion choices to generate for
|
|
||||||
each input message. (default: :obj:`1`)
|
|
||||||
response_format (object, optional): An object specifying the format
|
|
||||||
that the model must output. Compatible with GPT-4 Turbo and all
|
|
||||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
|
||||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
|
||||||
message the model generates is valid JSON. Important: when using
|
|
||||||
JSON mode, you must also instruct the model to produce JSON
|
|
||||||
yourself via a system or user message. Without this, the model
|
|
||||||
may generate an unending stream of whitespace until the generation
|
|
||||||
reaches the token limit, resulting in a long-running and seemingly
|
|
||||||
"stuck" request. Also note that the message content may be
|
|
||||||
partially cut off if finish_reason="length", which indicates the
|
|
||||||
generation exceeded max_tokens or the conversation exceeded the
|
|
||||||
max context length.
|
|
||||||
stream (bool, optional): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
||||||
will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens to generate
|
|
||||||
in the chat completion. The total length of input tokens and
|
|
||||||
generated tokens is limited by the model's context length.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
|
||||||
they appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. See more information about frequency and
|
|
||||||
presence penalties. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
|
||||||
existing frequency in the text so far, decreasing the model's
|
|
||||||
likelihood to repeat the same line verbatim. See more information
|
|
||||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
|
||||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
|
||||||
appearing in the completion. Accepts a json object that maps tokens
|
|
||||||
(specified by their token ID in the tokenizer) to an associated
|
|
||||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
|
||||||
is added to the logits generated by the model prior to sampling.
|
|
||||||
The exact effect will vary per model, but values between:obj:` -1`
|
|
||||||
and :obj:`1` should decrease or increase likelihood of selection;
|
|
||||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
|
||||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
|
||||||
user (str, optional): A unique identifier representing your end-user,
|
|
||||||
which can help OpenAI to monitor and detect abuse.
|
|
||||||
(default: :obj:`""`)
|
|
||||||
tools (list[FunctionTool], optional): A list of tools the model may
|
|
||||||
call. Currently, only functions are supported as a tool. Use this
|
|
||||||
to provide a list of functions the model may generate JSON inputs
|
|
||||||
for. A max of 128 functions are supported.
|
|
||||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
|
||||||
any) tool is called by the model. :obj:`"none"` means the model
|
|
||||||
will not call any tool and instead generates a message.
|
|
||||||
:obj:`"auto"` means the model can pick between generating a
|
|
||||||
message or calling one or more tools. :obj:`"required"` means the
|
|
||||||
model must call one or more tools. Specifying a particular tool
|
|
||||||
via {"type": "function", "function": {"name": "my_function"}}
|
|
||||||
forces the model to call that tool. :obj:`"none"` is the default
|
|
||||||
when no tools are present. :obj:`"auto"` is the default if tools
|
|
||||||
are present.
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2 # openai default: 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
n: int = 1
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: dict = Field(default_factory=dict)
|
|
||||||
user: str = ""
|
|
||||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
|
||||||
r"""Convert the current configuration to a dictionary.
|
|
||||||
|
|
||||||
This method converts the current configuration object to a dictionary
|
|
||||||
representation, which can be used for serialization or other purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Any]: A dictionary representation of the current
|
|
||||||
configuration.
|
|
||||||
"""
|
|
||||||
config_dict = self.model_dump()
|
|
||||||
if self.tools:
|
|
||||||
from camel.toolkits import FunctionTool
|
|
||||||
|
|
||||||
tools_schema = []
|
|
||||||
for tool in self.tools:
|
|
||||||
if not isinstance(tool, FunctionTool):
|
|
||||||
raise ValueError(
|
|
||||||
f"The tool {tool} should "
|
|
||||||
"be an instance of `FunctionTool`."
|
|
||||||
)
|
|
||||||
tools_schema.append(tool.get_openai_tool_schema())
|
|
||||||
config_dict["tools"] = NOT_GIVEN
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import ClassVar, Optional, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class QwenConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
Qwen API. You can refer to the following link for more details:
|
|
||||||
https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream (bool, optional): Whether to stream the response.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
temperature (float, optional): Controls the diversity and focus of
|
|
||||||
the generated results. Lower values make the output more focused,
|
|
||||||
while higher values make it more diverse. (default: :obj:`0.3`)
|
|
||||||
top_p (float, optional): Controls the diversity and focus of the
|
|
||||||
generated results. Higher values make the output more diverse,
|
|
||||||
while lower values make it more focused. (default: :obj:`0.9`)
|
|
||||||
presence_penalty (float, optional): Controls the repetition of
|
|
||||||
content in the generated results. Positive values reduce the
|
|
||||||
repetition of content, while negative values increase it.
|
|
||||||
(default: :obj:`0.0`)
|
|
||||||
response_format (object, optional): Specifies the format of the
|
|
||||||
returned content. The available values are `{"type": "text"}` or
|
|
||||||
`{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
|
|
||||||
will output a standard JSON string.
|
|
||||||
(default: :obj:`{"type": "text"}`)
|
|
||||||
max_tokens (Union[int, NotGiven], optional): Allows the model to
|
|
||||||
generate the maximum number of tokens.
|
|
||||||
(default: :obj:`NOT_GIVEN`)
|
|
||||||
seed (int, optional): Sets the seed parameter to make the text
|
|
||||||
generation process more deterministic, typically used to ensure
|
|
||||||
that the results are consistent across model runs. By passing the
|
|
||||||
same seed value (specified by you) in each model call while
|
|
||||||
keeping other parameters unchanged, the model is likely to return
|
|
||||||
the same result.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
stop (str or list, optional): Using the stop parameter, the model will
|
|
||||||
automatically stop generating text when it is about to include the
|
|
||||||
specified string or token_id. You can use the stop parameter to
|
|
||||||
control the output of the model by passing sensitive words.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
tools (list, optional): Specifies an array of tools that the model can
|
|
||||||
call. It can contain one or more tool objects. During a function
|
|
||||||
call process, the model will select one tool from the array.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
extra_body (dict, optional): Additional parameters to be sent to the
|
|
||||||
Qwen API. If you want to enable internet search, you can set this
|
|
||||||
parameter to `{"enable_search": True}`.
|
|
||||||
(default: :obj:`{"enable_search": False}`)
|
|
||||||
include_usage (bool, optional): When streaming, specifies whether to
|
|
||||||
include usage information in `stream_options`. (default:
|
|
||||||
:obj:`True`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
stream: bool = False
|
|
||||||
temperature: float = 0.3
|
|
||||||
top_p: float = 0.9
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: ClassVar[dict] = {"type": "text"}
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
seed: Optional[int] = None
|
|
||||||
stop: Optional[Union[str, list]] = None
|
|
||||||
extra_body: ClassVar[dict] = {"enable_search": False}
|
|
||||||
|
|
||||||
def __init__(self, include_usage: bool = True, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# Only set stream_options when stream is True
|
|
||||||
# Otherwise, it will raise error when calling the API
|
|
||||||
if self.stream:
|
|
||||||
self.stream_options = {"include_usage": include_usage}
|
|
||||||
|
|
||||||
|
|
||||||
QWEN_API_PARAMS = {param for param in QwenConfig.model_fields.keys()}
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Optional, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
class RekaConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
Reka API.
|
|
||||||
|
|
||||||
Reference: https://docs.reka.ai/api-reference/chat/create
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (Optional[float], optional): temperature the temperature
|
|
||||||
to use for sampling, e.g. 0.5.
|
|
||||||
top_p (Optional[float], optional): the cumulative probability of
|
|
||||||
tokens to generate, e.g. 0.9. Defaults to None.
|
|
||||||
top_k (Optional[int], optional): Parameter which forces the model to
|
|
||||||
only consider the tokens with the `top_k` highest probabilities at
|
|
||||||
the next step. Defaults to 1024.
|
|
||||||
max_tokens (Optional[int], optional): the maximum number of tokens to
|
|
||||||
generate, e.g. 100. Defaults to None.
|
|
||||||
stop (Optional[Union[str,list[str]]]): Stop generation if this token
|
|
||||||
is detected. Or if one of these tokens is detected when providing
|
|
||||||
a string list.
|
|
||||||
seed (Optional[int], optional): the random seed to use for sampling, e.
|
|
||||||
g. 42. Defaults to None.
|
|
||||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
|
||||||
they appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. See more information about frequency and
|
|
||||||
presence penalties. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
|
||||||
existing frequency in the text so far, decreasing the model's
|
|
||||||
likelihood to repeat the same line verbatim. See more information
|
|
||||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
|
||||||
use_search_engine (Optional[bool]): Whether to consider using search
|
|
||||||
engine to complete the request. Note that even if this is set to
|
|
||||||
`True`, the model might decide to not use search.
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_p: Optional[float] = None
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
stop: Optional[Union[str, list[str]]] = None
|
|
||||||
seed: Optional[int] = None
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
use_search_engine: Optional[bool] = False
|
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
|
||||||
config_dict = super().as_dict()
|
|
||||||
if "tools" in config_dict:
|
|
||||||
del config_dict["tools"] # Reka does not support tool calling
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
REKA_API_PARAMS = {param for param in RekaConfig().model_fields.keys()}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Optional, Sequence, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class SambaVerseAPIConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
SambaVerse API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.7`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`0.95`)
|
|
||||||
top_k (int, optional): Only sample from the top K options for each
|
|
||||||
subsequent token. Used to remove "long tail" low probability
|
|
||||||
responses.
|
|
||||||
(default: :obj:`50`)
|
|
||||||
max_tokens (Optional[int], optional): The maximum number of tokens to
|
|
||||||
generate, e.g. 100.
|
|
||||||
(default: :obj:`2048`)
|
|
||||||
repetition_penalty (Optional[float], optional): The parameter for
|
|
||||||
repetition penalty. 1.0 means no penalty.
|
|
||||||
(default: :obj:`1.0`)
|
|
||||||
stop (Optional[Union[str,list[str]]]): Stop generation if this token
|
|
||||||
is detected. Or if one of these tokens is detected when providing
|
|
||||||
a string list.
|
|
||||||
(default: :obj:`""`)
|
|
||||||
stream (Optional[bool]): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
Currently SambaVerse API doesn't support stream mode.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: Optional[float] = 0.7
|
|
||||||
top_p: Optional[float] = 0.95
|
|
||||||
top_k: Optional[int] = 50
|
|
||||||
max_tokens: Optional[int] = 2048
|
|
||||||
repetition_penalty: Optional[float] = 1.0
|
|
||||||
stop: Optional[Union[str, list[str]]] = ""
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
|
||||||
config_dict = super().as_dict()
|
|
||||||
if "tools" in config_dict:
|
|
||||||
del config_dict["tools"] # SambaNova does not support tool calling
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
SAMBA_VERSE_API_PARAMS = {
|
|
||||||
param for param in SambaVerseAPIConfig().model_fields.keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SambaCloudAPIConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
OpenAI API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`1.0`)
|
|
||||||
n (int, optional): How many chat completion choices to generate for
|
|
||||||
each input message. (default: :obj:`1`)
|
|
||||||
response_format (object, optional): An object specifying the format
|
|
||||||
that the model must output. Compatible with GPT-4 Turbo and all
|
|
||||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
|
||||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
|
||||||
message the model generates is valid JSON. Important: when using
|
|
||||||
JSON mode, you must also instruct the model to produce JSON
|
|
||||||
yourself via a system or user message. Without this, the model
|
|
||||||
may generate an unending stream of whitespace until the generation
|
|
||||||
reaches the token limit, resulting in a long-running and seemingly
|
|
||||||
"stuck" request. Also note that the message content may be
|
|
||||||
partially cut off if finish_reason="length", which indicates the
|
|
||||||
generation exceeded max_tokens or the conversation exceeded the
|
|
||||||
max context length.
|
|
||||||
stream (bool, optional): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
||||||
will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens to generate
|
|
||||||
in the chat completion. The total length of input tokens and
|
|
||||||
generated tokens is limited by the model's context length.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
|
||||||
they appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. See more information about frequency and
|
|
||||||
presence penalties. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
|
||||||
existing frequency in the text so far, decreasing the model's
|
|
||||||
likelihood to repeat the same line verbatim. See more information
|
|
||||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
|
||||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
|
||||||
appearing in the completion. Accepts a json object that maps tokens
|
|
||||||
(specified by their token ID in the tokenizer) to an associated
|
|
||||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
|
||||||
is added to the logits generated by the model prior to sampling.
|
|
||||||
The exact effect will vary per model, but values between:obj:` -1`
|
|
||||||
and :obj:`1` should decrease or increase likelihood of selection;
|
|
||||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
|
||||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
|
||||||
user (str, optional): A unique identifier representing your end-user,
|
|
||||||
which can help OpenAI to monitor and detect abuse.
|
|
||||||
(default: :obj:`""`)
|
|
||||||
tools (list[FunctionTool], optional): A list of tools the model may
|
|
||||||
call. Currently, only functions are supported as a tool. Use this
|
|
||||||
to provide a list of functions the model may generate JSON inputs
|
|
||||||
for. A max of 128 functions are supported.
|
|
||||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
|
||||||
any) tool is called by the model. :obj:`"none"` means the model
|
|
||||||
will not call any tool and instead generates a message.
|
|
||||||
:obj:`"auto"` means the model can pick between generating a
|
|
||||||
message or calling one or more tools. :obj:`"required"` means the
|
|
||||||
model must call one or more tools. Specifying a particular tool
|
|
||||||
via {"type": "function", "function": {"name": "my_function"}}
|
|
||||||
forces the model to call that tool. :obj:`"none"` is the default
|
|
||||||
when no tools are present. :obj:`"auto"` is the default if tools
|
|
||||||
are present.
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2 # openai default: 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
n: int = 1
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: dict = Field(default_factory=dict)
|
|
||||||
user: str = ""
|
|
||||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
SAMBA_CLOUD_API_PARAMS = {
|
|
||||||
param for param in SambaCloudAPIConfig().model_fields.keys()
|
|
||||||
}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Sequence, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class TogetherAIConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
OpenAI API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`1.0`)
|
|
||||||
n (int, optional): How many chat completion choices to generate for
|
|
||||||
each input message. (default: :obj:`1`)
|
|
||||||
response_format (object, optional): An object specifying the format
|
|
||||||
that the model must output. Compatible with GPT-4 Turbo and all
|
|
||||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
|
||||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
|
||||||
message the model generates is valid JSON. Important: when using
|
|
||||||
JSON mode, you must also instruct the model to produce JSON
|
|
||||||
yourself via a system or user message. Without this, the model
|
|
||||||
may generate an unending stream of whitespace until the generation
|
|
||||||
reaches the token limit, resulting in a long-running and seemingly
|
|
||||||
"stuck" request. Also note that the message content may be
|
|
||||||
partially cut off if finish_reason="length", which indicates the
|
|
||||||
generation exceeded max_tokens or the conversation exceeded the
|
|
||||||
max context length.
|
|
||||||
stream (bool, optional): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
||||||
will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens to generate
|
|
||||||
in the chat completion. The total length of input tokens and
|
|
||||||
generated tokens is limited by the model's context length.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
|
||||||
they appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. See more information about frequency and
|
|
||||||
presence penalties. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
|
||||||
existing frequency in the text so far, decreasing the model's
|
|
||||||
likelihood to repeat the same line verbatim. See more information
|
|
||||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
|
||||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
|
||||||
appearing in the completion. Accepts a json object that maps tokens
|
|
||||||
(specified by their token ID in the tokenizer) to an associated
|
|
||||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
|
||||||
is added to the logits generated by the model prior to sampling.
|
|
||||||
The exact effect will vary per model, but values between:obj:` -1`
|
|
||||||
and :obj:`1` should decrease or increase likelihood of selection;
|
|
||||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
|
||||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
|
||||||
user (str, optional): A unique identifier representing your end-user,
|
|
||||||
which can help OpenAI to monitor and detect abuse.
|
|
||||||
(default: :obj:`""`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2 # openai default: 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
n: int = 1
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: dict = Field(default_factory=dict)
|
|
||||||
user: str = ""
|
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
|
||||||
config_dict = super().as_dict()
|
|
||||||
if "tools" in config_dict:
|
|
||||||
del config_dict["tools"] # Currently does not support tool calling
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
TOGETHERAI_API_PARAMS = {
|
|
||||||
param for param in TogetherAIConfig.model_fields.keys()
|
|
||||||
}
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
# flake8: noqa: E501
|
|
||||||
class VLLMConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
OpenAI API.
|
|
||||||
|
|
||||||
Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`1.0`)
|
|
||||||
n (int, optional): How many chat completion choices to generate for
|
|
||||||
each input message. (default: :obj:`1`)
|
|
||||||
response_format (object, optional): An object specifying the format
|
|
||||||
that the model must output. Compatible with GPT-4 Turbo and all
|
|
||||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
|
||||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
|
||||||
message the model generates is valid JSON. Important: when using
|
|
||||||
JSON mode, you must also instruct the model to produce JSON
|
|
||||||
yourself via a system or user message. Without this, the model
|
|
||||||
may generate an unending stream of whitespace until the generation
|
|
||||||
reaches the token limit, resulting in a long-running and seemingly
|
|
||||||
"stuck" request. Also note that the message content may be
|
|
||||||
partially cut off if finish_reason="length", which indicates the
|
|
||||||
generation exceeded max_tokens or the conversation exceeded the
|
|
||||||
max context length.
|
|
||||||
stream (bool, optional): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
||||||
will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens to generate
|
|
||||||
in the chat completion. The total length of input tokens and
|
|
||||||
generated tokens is limited by the model's context length.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
|
||||||
they appear in the text so far, increasing the model's likelihood
|
|
||||||
to talk about new topics. See more information about frequency and
|
|
||||||
presence penalties. (default: :obj:`0.0`)
|
|
||||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
|
||||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
|
||||||
existing frequency in the text so far, decreasing the model's
|
|
||||||
likelihood to repeat the same line verbatim. See more information
|
|
||||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
|
||||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
|
||||||
appearing in the completion. Accepts a json object that maps tokens
|
|
||||||
(specified by their token ID in the tokenizer) to an associated
|
|
||||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
|
||||||
is added to the logits generated by the model prior to sampling.
|
|
||||||
The exact effect will vary per model, but values between:obj:` -1`
|
|
||||||
and :obj:`1` should decrease or increase likelihood of selection;
|
|
||||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
|
||||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
|
||||||
user (str, optional): A unique identifier representing your end-user,
|
|
||||||
which can help OpenAI to monitor and detect abuse.
|
|
||||||
(default: :obj:`""`)
|
|
||||||
logprobs: Whether to return log probabilities of the output tokens or
|
|
||||||
not. If true, returns the log probabilities of each output token
|
|
||||||
returned in the `logits` of `message`. (default: :obj:`None`)
|
|
||||||
top_logprobs: An integer between 0 and 20 specifying the number of
|
|
||||||
most likely tokens to return at each token position, each with an
|
|
||||||
associated log probability. `logprobs` must be set to `true` if
|
|
||||||
this parameter is used. (default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2 # openai default: 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
n: int = 1
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
presence_penalty: float = 0.0
|
|
||||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
|
||||||
frequency_penalty: float = 0.0
|
|
||||||
logit_bias: dict = Field(default_factory=dict)
|
|
||||||
user: str = ""
|
|
||||||
logprobs: Optional[bool] = None
|
|
||||||
top_logprobs: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class YiConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using the
|
|
||||||
Yi API. You can refer to the following link for more details:
|
|
||||||
https://platform.lingyiwanwu.com/docs/api-reference
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
|
||||||
any) tool is called by the model. :obj:`"none"` means the model
|
|
||||||
will not call any tool and instead generates a message.
|
|
||||||
:obj:`"auto"` means the model can pick between generating a
|
|
||||||
message or calling one or more tools. :obj:`"required"` or
|
|
||||||
specifying a particular tool via
|
|
||||||
{"type": "function", "function": {"name": "some_function"}}
|
|
||||||
can be used to guide the model to use tools more strongly.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): Specifies the maximum number of tokens
|
|
||||||
the model can generate. This sets an upper limit, but does not
|
|
||||||
guarantee that this number will always be reached.
|
|
||||||
(default: :obj:`5000`)
|
|
||||||
top_p (float, optional): Controls the randomness of the generated
|
|
||||||
results. Lower values lead to less randomness, while higher
|
|
||||||
values increase randomness. (default: :obj:`0.9`)
|
|
||||||
temperature (float, optional): Controls the diversity and focus of
|
|
||||||
the generated results. Lower values make the output more focused,
|
|
||||||
while higher values make it more diverse. (default: :obj:`0.3`)
|
|
||||||
stream (bool, optional): If True, enables streaming output.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
top_p: float = 0.9
|
|
||||||
temperature: float = 0.3
|
|
||||||
stream: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
|
||||||
|
|
||||||
from camel.configs.base_config import BaseConfig
|
|
||||||
from camel.types import NOT_GIVEN, NotGiven
|
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAIConfig(BaseConfig):
|
|
||||||
r"""Defines the parameters for generating chat completions using OpenAI
|
|
||||||
compatibility
|
|
||||||
|
|
||||||
Reference: https://open.bigmodel.cn/dev/api#glm-4v
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temperature (float, optional): Sampling temperature to use, between
|
|
||||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
|
||||||
while lower values make it more focused and deterministic.
|
|
||||||
(default: :obj:`0.2`)
|
|
||||||
top_p (float, optional): An alternative to sampling with temperature,
|
|
||||||
called nucleus sampling, where the model considers the results of
|
|
||||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
|
||||||
the tokens comprising the top 10% probability mass are considered.
|
|
||||||
(default: :obj:`0.6`)
|
|
||||||
stream (bool, optional): If True, partial message deltas will be sent
|
|
||||||
as data-only server-sent events as they become available.
|
|
||||||
(default: :obj:`False`)
|
|
||||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
|
||||||
will stop generating further tokens. (default: :obj:`None`)
|
|
||||||
max_tokens (int, optional): The maximum number of tokens to generate
|
|
||||||
in the chat completion. The total length of input tokens and
|
|
||||||
generated tokens is limited by the model's context length.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
tools (list[FunctionTool], optional): A list of tools the model may
|
|
||||||
call. Currently, only functions are supported as a tool. Use this
|
|
||||||
to provide a list of functions the model may generate JSON inputs
|
|
||||||
for. A max of 128 functions are supported.
|
|
||||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
|
||||||
any) tool is called by the model. :obj:`"none"` means the model
|
|
||||||
will not call any tool and instead generates a message.
|
|
||||||
:obj:`"auto"` means the model can pick between generating a
|
|
||||||
message or calling one or more tools. :obj:`"required"` means the
|
|
||||||
model must call one or more tools. Specifying a particular tool
|
|
||||||
via {"type": "function", "function": {"name": "my_function"}}
|
|
||||||
forces the model to call that tool. :obj:`"none"` is the default
|
|
||||||
when no tools are present. :obj:`"auto"` is the default if tools
|
|
||||||
are present.
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.2
|
|
||||||
top_p: float = 0.6
|
|
||||||
stream: bool = False
|
|
||||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
|
||||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
|
||||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
ZHIPUAI_API_PARAMS = {param for param in ZhipuAIConfig.model_fields.keys()}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .base import BaseDatasetManager
|
|
||||||
from .huggingface import HuggingFaceDatasetManager
|
|
||||||
from .models import Record
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseDatasetManager",
|
|
||||||
"Record",
|
|
||||||
"HuggingFaceDatasetManager",
|
|
||||||
]
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, List
|
|
||||||
|
|
||||||
from camel.datahubs.models import Record
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDatasetManager(ABC):
|
|
||||||
r"""Abstract base class for dataset managers."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_dataset(self, name: str, **kwargs: Any) -> str:
|
|
||||||
r"""Creates a new dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name of the dataset.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The URL of the created dataset.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_datasets(
|
|
||||||
self, username: str, limit: int = 100, **kwargs: Any
|
|
||||||
) -> List[str]:
|
|
||||||
r"""Lists all datasets for the current user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username (str): The username of the user whose datasets to list.
|
|
||||||
limit (int): The maximum number of datasets to list.
|
|
||||||
(default::obj:`100`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: A list of dataset ids.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
|
||||||
r"""Deletes a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset to delete.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
records: List[Record],
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Adds records to a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
records (List[Record]): A list of records to add to the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
(default::obj:`"records/records.json"`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
records: List[Record],
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Updates records in a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
records (List[Record]): A list of records to update in the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
(default::obj:`"records/records.json"`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Record]:
|
|
||||||
r"""Lists records in a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
(default::obj:`"records/records.json"`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# New method for record deletion
|
|
||||||
@abstractmethod
|
|
||||||
def delete_record(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
record_id: str,
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Deletes a record from the dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
record_id (str): The ID of the record to delete.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
(default::obj:`"records/records.json"`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
@@ -1,433 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from camel.datahubs.base import BaseDatasetManager
|
|
||||||
from camel.datahubs.models import Record
|
|
||||||
from camel.logger import get_logger
|
|
||||||
from camel.types import HuggingFaceRepoType
|
|
||||||
from camel.utils import api_keys_required, dependencies_required
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceDatasetManager(BaseDatasetManager):
|
|
||||||
r"""A dataset manager for Hugging Face datasets. This class provides
|
|
||||||
methods to create, add, update, delete, and list records in a dataset on
|
|
||||||
the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token (str): The Hugging Face API token. If not provided, the token
|
|
||||||
will be read from the environment variable `HUGGING_FACE_TOKEN`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_keys_required("HUGGING_FACE_TOKEN")
|
|
||||||
@dependencies_required('huggingface_hub')
|
|
||||||
def __init__(self, token: Optional[str] = None):
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
|
|
||||||
self._api_key = token or os.getenv("HUGGING_FACE_TOKEN")
|
|
||||||
self.api = HfApi(token=self._api_key)
|
|
||||||
|
|
||||||
def create_dataset_card(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
description: str,
|
|
||||||
license: Optional[str] = None,
|
|
||||||
version: Optional[str] = None,
|
|
||||||
tags: Optional[List[str]] = None,
|
|
||||||
authors: Optional[List[str]] = None,
|
|
||||||
size_category: Optional[List[str]] = None,
|
|
||||||
language: Optional[List[str]] = None,
|
|
||||||
task_categories: Optional[List[str]] = None,
|
|
||||||
content: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
|
|
||||||
format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
description (str): A description of the dataset.
|
|
||||||
license (str): The license of the dataset. (default: :obj:`None`)
|
|
||||||
version (str): The version of the dataset. (default: :obj:`None`)
|
|
||||||
tags (list): A list of tags for the dataset.(default: :obj:`None`)
|
|
||||||
authors (list): A list of authors of the dataset. (default:
|
|
||||||
:obj:`None`)
|
|
||||||
size_category (list): A size category for the dataset. (default:
|
|
||||||
:obj:`None`)
|
|
||||||
language (list): A list of languages the dataset is in. (default:
|
|
||||||
:obj:`None`)
|
|
||||||
task_categories (list): A list of task categories. (default:
|
|
||||||
:obj:`None`)
|
|
||||||
content (str): Custom markdown content that the user wants to add
|
|
||||||
to the dataset card. (default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"license": license,
|
|
||||||
"authors": authors,
|
|
||||||
"task_categories": task_categories,
|
|
||||||
"language": language,
|
|
||||||
"tags": tags,
|
|
||||||
"pretty_name": dataset_name,
|
|
||||||
"size_categories": size_category,
|
|
||||||
"version": version,
|
|
||||||
"description": description,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Remove keys with None values
|
|
||||||
metadata = {k: v for k, v in metadata.items() if v}
|
|
||||||
|
|
||||||
card_content = (
|
|
||||||
"---\n"
|
|
||||||
+ yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
|
|
||||||
+ "\n---"
|
|
||||||
)
|
|
||||||
|
|
||||||
if content:
|
|
||||||
card_content += f"\n\n# Additional Information\n{content}\n"
|
|
||||||
|
|
||||||
self._upload_file(
|
|
||||||
file_content=card_content,
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
filepath="README.md",
|
|
||||||
file_type="md",
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_dataset(
|
|
||||||
self, name: str, private: bool = False, **kwargs: Any
|
|
||||||
) -> str:
|
|
||||||
r"""Creates a new dataset on the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name of the dataset.
|
|
||||||
private (bool): Whether the dataset should be private. defaults to
|
|
||||||
False.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The URL of the created dataset.
|
|
||||||
"""
|
|
||||||
from huggingface_hub.errors import RepositoryNotFoundError
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.api.repo_info(
|
|
||||||
repo_id=name,
|
|
||||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
self.api.create_repo(
|
|
||||||
repo_id=name,
|
|
||||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
|
||||||
private=private,
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"https://huggingface.co/datasets/{name}"
|
|
||||||
|
|
||||||
def list_datasets(
|
|
||||||
self, username: str, limit: int = 100, **kwargs: Any
|
|
||||||
) -> List[str]:
|
|
||||||
r"""Lists all datasets for the current user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username (str): The username of the user whose datasets to list.
|
|
||||||
limit (int): The maximum number of datasets to list.
|
|
||||||
(default: :obj:`100`)
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: A list of dataset ids.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return [
|
|
||||||
dataset.id
|
|
||||||
for dataset in self.api.list_datasets(
|
|
||||||
author=username, limit=limit, **kwargs
|
|
||||||
)
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error listing datasets: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
|
||||||
r"""Deletes a dataset from the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset to delete.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.api.delete_repo(
|
|
||||||
repo_id=dataset_name,
|
|
||||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
logger.info(f"Dataset '{dataset_name}' deleted successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting dataset '{dataset_name}': {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def add_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
records: List[Record],
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Adds records to a dataset on the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
records (List[Record]): A list of records to add to the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the dataset already has a records file.
|
|
||||||
"""
|
|
||||||
existing_records = self._download_records(
|
|
||||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if existing_records:
|
|
||||||
raise ValueError(
|
|
||||||
f"Dataset '{filepath}' already exists. "
|
|
||||||
f"Use `update_records` to modify."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._upload_records(
|
|
||||||
records=records,
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
filepath=filepath,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
records: List[Record],
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Updates records in a dataset on the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
records (List[Record]): A list of records to update in the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the dataset does not have an existing file to update
|
|
||||||
records in.
|
|
||||||
"""
|
|
||||||
existing_records = self._download_records(
|
|
||||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not existing_records:
|
|
||||||
logger.warning(
|
|
||||||
f"Dataset '{dataset_name}' does not have existing "
|
|
||||||
"records. Adding new records."
|
|
||||||
)
|
|
||||||
self._upload_records(
|
|
||||||
records=records,
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
filepath=filepath,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
old_dict = {record.id: record for record in existing_records}
|
|
||||||
new_dict = {record.id: record for record in records}
|
|
||||||
merged_dict = old_dict.copy()
|
|
||||||
merged_dict.update(new_dict)
|
|
||||||
|
|
||||||
self._upload_records(
|
|
||||||
records=list(merged_dict.values()),
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
filepath=filepath,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_record(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
record_id: str,
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
r"""Deletes a record from the dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
record_id (str): The ID of the record to delete.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the dataset does not have an existing file to delete
|
|
||||||
records from.
|
|
||||||
"""
|
|
||||||
existing_records = self._download_records(
|
|
||||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not existing_records:
|
|
||||||
raise ValueError(
|
|
||||||
f"Dataset '{dataset_name}' does not have an existing file to "
|
|
||||||
f"delete records from."
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered_records = [
|
|
||||||
record for record in existing_records if record.id != record_id
|
|
||||||
]
|
|
||||||
|
|
||||||
self._upload_records(
|
|
||||||
records=filtered_records,
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
filepath=filepath,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_records(
|
|
||||||
self,
|
|
||||||
dataset_name: str,
|
|
||||||
filepath: str = "records/records.json",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Record]:
|
|
||||||
r"""Lists all records in a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name (str): The name of the dataset.
|
|
||||||
filepath (str): The path to the file containing the records.
|
|
||||||
kwargs (Any): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Record]: A list of records in the dataset.
|
|
||||||
"""
|
|
||||||
return self._download_records(
|
|
||||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _download_records(
    self, dataset_name: str, filepath: str, **kwargs: Any
) -> List[Record]:
    r"""Downloads the records file from the dataset repo and parses it.

    Args:
        dataset_name (str): The Hugging Face dataset repository ID.
        filepath (str): Path of the records file inside the repository.
        **kwargs (Any): Extra kwargs forwarded to ``hf_hub_download``.

    Returns:
        List[Record]: Parsed records; empty when the file does not exist.
    """
    # Imported lazily so the module loads without huggingface_hub.
    from huggingface_hub import hf_hub_download
    from huggingface_hub.errors import EntryNotFoundError

    try:
        local_path = hf_hub_download(
            repo_id=dataset_name,
            filename=filepath,
            repo_type=HuggingFaceRepoType.DATASET.value,
            token=self._api_key,
            **kwargs,
        )

        with open(local_path, "r") as fh:
            raw_records = json.load(fh)

        return [Record(**entry) for entry in raw_records]
    except EntryNotFoundError:
        # A missing records file simply means no records yet.
        logger.info(f"No records found for dataset '{dataset_name}'.")
        return []
    except Exception as e:
        logger.error(f"Error downloading or processing records: {e}")
        raise e
|
|
||||||
|
|
||||||
def _upload_records(
    self,
    records: List[Record],
    dataset_name: str,
    filepath: str,
    **kwargs: Any,
):
    r"""Serializes records to a temporary JSON file and uploads it to the
    dataset repository, replacing any existing file at ``filepath``.

    Args:
        records (List[Record]): The records to upload.
        dataset_name (str): The Hugging Face dataset repository ID.
        filepath (str): Destination path of the file inside the repo.
        **kwargs (Any): Extra kwargs forwarded to ``upload_file``.
    """
    # Write to a named temp file (delete=False) so the Hub client can
    # re-open it by path after the handle is closed.
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", newline="", encoding="utf-8"
    ) as tmp:
        json.dump([rec.model_dump() for rec in records], tmp)
        tmp_path = tmp.name

    try:
        self.api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo=filepath,
            repo_id=dataset_name,
            repo_type=HuggingFaceRepoType.DATASET.value,
            **kwargs,
        )
    except Exception as e:
        logger.error(f"Error uploading records file: {e}")
        raise
    finally:
        # Remove the temp file whether or not the upload succeeded.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
||||||
|
|
||||||
def _upload_file(
    self,
    file_content: str,
    dataset_name: str,
    filepath: str,
    file_type: str = "json",
    **kwargs: Any,
):
    r"""Uploads arbitrary file content to the dataset repository.

    The content is staged in a temporary file and pushed via the Hub
    client. JSON content may be given either as a JSON string or as a
    JSON-serializable object; ``md``/``txt`` content is written verbatim.

    Args:
        file_content (str): The content to upload. For ``file_type="json"``
            this may also be a JSON-serializable object.
        dataset_name (str): The Hugging Face dataset repository ID.
        filepath (str): Destination path of the file inside the repo.
        file_type (str): One of ``"json"``, ``"md"`` or ``"txt"``.
            (default: :obj:`"json"`)
        **kwargs (Any): Extra kwargs forwarded to ``upload_file``.

    Raises:
        ValueError: If the content is not valid for the declared
            ``file_type``, or the ``file_type`` is unsupported.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=f".{file_type}"
    ) as f:
        if file_type == "json":
            if isinstance(file_content, str):
                # A string must itself be valid JSON.
                try:
                    json_content = json.loads(file_content)
                except json.JSONDecodeError:
                    raise ValueError(
                        "Invalid JSON string provided for file_content."
                    )
            else:
                # A non-string object must round-trip through json.dumps.
                try:
                    json.dumps(file_content)
                    json_content = file_content
                except (TypeError, ValueError):
                    raise ValueError(
                        "file_content is not JSON serializable."
                    )

            json.dump(json_content, f)
        elif file_type == "md" or file_type == "txt":
            f.write(file_content)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        temp_file_path = f.name

    try:
        self.api.upload_file(
            path_or_fileobj=temp_file_path,
            path_in_repo=filepath,
            repo_id=dataset_name,
            repo_type=HuggingFaceRepoType.DATASET.value,
            **kwargs,
        )
        logger.info(f"File uploaded successfully: {filepath}")
    except Exception as e:
        logger.error(f"Error uploading file: {e}")
        raise
    finally:
        # Fix: cleanup previously ran only on success (the `raise` in the
        # except branch skipped it), leaking the temp file on upload
        # failure. Use `finally`, matching _upload_records.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class Record(BaseModel):
    r"""A single dataset record with an optional ID and metadata."""

    # Unique identifier of the record; None until one is assigned.
    id: Optional[str] = None
    # Arbitrary auxiliary data attached to the record.
    metadata: Optional[Dict[str, Any]] = None
    # The record payload; the only required field.
    content: Dict[str, Any]
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .base import BaseEmbedding
|
|
||||||
from .mistral_embedding import MistralEmbedding
|
|
||||||
from .openai_compatible_embedding import OpenAICompatibleEmbedding
|
|
||||||
from .openai_embedding import OpenAIEmbedding
|
|
||||||
from .sentence_transformers_embeddings import SentenceTransformerEncoder
|
|
||||||
from .vlm_embedding import VisionLanguageEmbedding
|
|
||||||
|
|
||||||
# Public API of the embeddings package (controls `from ... import *`).
__all__ = [
    "BaseEmbedding",
    "OpenAIEmbedding",
    "SentenceTransformerEncoder",
    "VisionLanguageEmbedding",
    "MistralEmbedding",
    "OpenAICompatibleEmbedding",
]
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Generic, TypeVar
|
|
||||||
|
|
||||||
T = TypeVar('T')


class BaseEmbedding(ABC, Generic[T]):
    r"""Abstract base class for text embedding functionalities."""

    @abstractmethod
    def embed_list(
        self,
        objs: list[T],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[T]): The objects for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list that represents the
                generated embedding as a list of floating-point numbers.
        """
        pass

    def embed(
        self,
        obj: T,
        **kwargs: Any,
    ) -> list[float]:
        r"""Generates an embedding for the given text.

        Args:
            obj (T): The object for which to generate the embedding.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[float]: A list of floating-point numbers representing the
                generated embedding.
        """
        # Delegate to the batch API with a single-element batch.
        batch = self.embed_list([obj], **kwargs)
        return batch[0]

    @abstractmethod
    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        pass
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from camel.embeddings.base import BaseEmbedding
|
|
||||||
from camel.types import EmbeddingModelType
|
|
||||||
from camel.utils import api_keys_required
|
|
||||||
|
|
||||||
|
|
||||||
class MistralEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using Mistral's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`MISTRAL_EMBED`)
        api_key (str, optional): The API key for authenticating with the
            Mistral service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`None`)

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = (EmbeddingModelType.MISTRAL_EMBED),
        api_key: str | None = None,
        dimensions: int | None = None,
    ) -> None:
        # Imported lazily so the module loads without the mistralai SDK.
        from mistralai import Mistral

        # Only Mistral-family embedding models are supported here.
        if not model_type.is_mistral:
            raise ValueError("Invalid Mistral embedding model type.")
        self.model_type = model_type
        # Use the model's native dimensionality unless the caller
        # explicitly requests one.
        if dimensions is not None:
            assert isinstance(dimensions, int)
        self.output_dim = (
            model_type.output_dim if dimensions is None else dimensions
        )
        self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
        self._client = Mistral(api_key=self._api_key)

    @api_keys_required("MISTRAL_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.
        """
        # TODO: count tokens
        api_response = self._client.embeddings.create(
            inputs=objs,
            model=self.model_type.value,
            **kwargs,
        )
        return [item.embedding for item in api_response.data]  # type: ignore[misc,union-attr]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from camel.embeddings.base import BaseEmbedding
|
|
||||||
from camel.utils import api_keys_required
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatibleEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities supporting OpenAI
    compatibility.

    Args:
        model_type (str): The model type to be used for text embeddings.
        api_key (str): The API key for authenticating with the model service.
        url (str): The url to the model service.
    """

    def __init__(
        self,
        model_type: str,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        self.model_type = model_type
        # Unknown until the first embed_list call reveals it.
        self.output_dim: Optional[int] = None

        # NOTE(review): env var names intentionally keep the historical
        # "COMPATIBILIY" spelling — existing deployments rely on them.
        self._api_key = api_key or os.environ.get(
            "OPENAI_COMPATIBILIY_API_KEY"
        )
        self._url = url or os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
        self._client = OpenAI(
            timeout=60,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )

    @api_keys_required("OPENAI_COMPATIBILIY_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.
        """
        api_response = self._client.embeddings.create(
            input=objs,
            model=self.model_type,
            **kwargs,
        )
        # Record the dimensionality observed in the first response.
        self.output_dim = len(api_response.data[0].embedding)
        return [item.embedding for item in api_response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        if self.output_dim is None:
            raise ValueError(
                "Output dimension is not yet determined. Call "
                "'embed_list' first."
            )
        return self.output_dim
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from camel.embeddings.base import BaseEmbedding
|
|
||||||
from camel.types import NOT_GIVEN, EmbeddingModelType, NotGiven
|
|
||||||
from camel.utils import api_keys_required
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbedding(BaseEmbedding[str]):
    r"""Provides text embedding functionalities using OpenAI's models.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`TEXT_EMBEDDING_3_SMALL`)
        api_key (str, optional): The API key for authenticating with the
            OpenAI service. (default: :obj:`None`)
        dimensions (int, optional): The text embedding output dimensions.
            (default: :obj:`NOT_GIVEN`)

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self,
        model_type: EmbeddingModelType = (
            EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
        ),
        api_key: str | None = None,
        dimensions: int | NotGiven = NOT_GIVEN,
    ) -> None:
        if not model_type.is_openai:
            raise ValueError("Invalid OpenAI embedding model type.")
        self.model_type = model_type
        # Default to the model's native dimensionality when the caller
        # does not specify one.
        if dimensions == NOT_GIVEN:
            self.output_dim = model_type.output_dim
        else:
            assert isinstance(dimensions, int)
            self.output_dim = dimensions
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(timeout=60, max_retries=3, api_key=self._api_key)

    @api_keys_required("OPENAI_API_KEY")
    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts.

        Args:
            objs (list[str]): The texts for which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API.

        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.
        """
        # TODO: count tokens
        # ada-002 does not accept a `dimensions` argument; pass it only
        # for the newer embedding models.
        if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
            api_response = self.client.embeddings.create(
                input=objs,
                model=self.model_type.value,
                **kwargs,
            )
        else:
            api_response = self.client.embeddings.create(
                input=objs,
                model=self.model_type.value,
                dimensions=self.output_dim,
                **kwargs,
            )
        return [item.embedding for item in api_response.data]

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from numpy import ndarray
|
|
||||||
|
|
||||||
from camel.embeddings.base import BaseEmbedding
|
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEncoder(BaseEmbedding[str]):
    r"""This class provides functionalities to generate text
    embeddings using `Sentence Transformers`.

    References:
        https://www.sbert.net/
    """

    def __init__(
        self,
        model_name: str = "intfloat/e5-large-v2",
        **kwargs,
    ):
        r"""Initializes the :obj:`SentenceTransformerEmbedding` class
        with the specified transformer model.

        Args:
            model_name (str, optional): The name of the model to use.
                (default: :obj:`intfloat/e5-large-v2`)
            **kwargs (optional): Additional arguments of
                :class:`SentenceTransformer`, such as :obj:`prompts` etc.
        """
        # Imported lazily so the module loads without sentence_transformers.
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name, **kwargs)

    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts using the model.

        Args:
            objs (list[str]): The texts for which to generate the
                embeddings.

        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.
        """
        if not objs:
            raise ValueError("Input text list is empty")
        # Normalized embeddings give unit-length vectors.
        vectors = self.model.encode(
            objs, normalize_embeddings=True, **kwargs
        )
        assert isinstance(vectors, ndarray)
        return vectors.tolist()

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embeddings.
        """
        dim = self.model.get_sentence_embedding_dimension()
        assert isinstance(dim, int)
        return dim
|
|
||||||
@@ -1,149 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from camel.embeddings import BaseEmbedding
|
|
||||||
from camel.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
    r"""Provides image embedding functionalities using multimodal model.

    Args:
        model_name : The model type to be used for generating embeddings.
            And the default value is: obj:`openai/clip-vit-base-patch32`.

    Raises:
        RuntimeError: If an unsupported model type is specified.
    """

    def __init__(
        self, model_name: str = "openai/clip-vit-base-patch32"
    ) -> None:
        r"""Initializes the :obj:`VisionLanguageEmbedding` class with a
        specified model and return the dimension of embeddings.

        Args:
            model_name (str, optional): The version name of the model to use.
                (default: :obj:`openai/clip-vit-base-patch32`)
        """
        # Imported lazily so the module loads without transformers.
        from transformers import AutoModel, AutoProcessor

        try:
            self.model = AutoModel.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"Failed to load model '{model_name}': {e}")

        self.valid_processor_kwargs = []
        self.valid_model_kwargs = []

        # NOTE(review): probes a private transformers attribute
        # (`_valid_processor_keys`); not every processor exposes it,
        # hence the broad fallback below.
        try:
            self.valid_processor_kwargs = (
                self.processor.image_processor._valid_processor_keys
            )
            self.valid_model_kwargs = [
                "pixel_values",
                "return_dict",
                "interpolate_pos_encoding",
            ]
        except Exception:
            logger.warning("not typically processor and model structure")
            pass
        # Embedding dimensionality; resolved lazily on first use.
        self.dim: Optional[int] = None

    def embed_list(
        self, objs: List[Union[Image.Image, str]], **kwargs: Any
    ) -> List[List[float]]:
        """Generates embeddings for the given images or texts.

        Args:
            objs (List[Image.Image|str]): The list of images or texts for
                which to generate the embeddings.
            image_processor_kwargs: Extra kwargs passed to the image processor.
            tokenizer_kwargs: Extra kwargs passed to the text tokenizer
                (processor).
            model_kwargs: Extra kwargs passed to the main model.

        Returns:
            List[List[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.

        Raises:
            ValueError: If the input type is not `Image.Image` or `str`.
        """
        if not objs:
            raise ValueError("Input objs list is empty.")

        # Per-component kwargs are passed nested inside **kwargs.
        image_processor_kwargs: Optional[dict] = kwargs.get(
            'image_processor_kwargs', {}
        )
        tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
        model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})

        result_list = []
        for obj in objs:
            # Dispatch per item: images go through get_image_features,
            # strings through get_text_features.
            if isinstance(obj, Image.Image):
                image_input = self.processor(
                    images=obj,
                    return_tensors="pt",
                    padding=True,
                    **image_processor_kwargs,
                )
                # squeeze(dim=0) drops the singleton batch dimension.
                image_feature = (
                    self.model.get_image_features(
                        **image_input, **model_kwargs
                    )
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(image_feature)
            elif isinstance(obj, str):
                text_input = self.processor(
                    text=obj,
                    return_tensors="pt",
                    padding=True,
                    **tokenizer_kwargs,
                )
                text_feature = (
                    self.model.get_text_features(**text_input, **model_kwargs)
                    .squeeze(dim=0)
                    .tolist()
                )
                result_list.append(text_feature)
            else:
                raise ValueError("Input type is not image nor text.")

        # Cache the observed dimensionality for get_output_dim.
        self.dim = len(result_list[0])

        # Text and image embeddings must share one dimensionality.
        if any(len(result) != self.dim for result in result_list):
            raise ValueError("Dimensionality is not consistent.")

        return result_list

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        # If no embedding has been produced yet, probe the model with a
        # dummy text input to discover the dimensionality.
        if self.dim is None:
            text = 'dimension'
            inputs = self.processor(text=[text], return_tensors="pt")
            self.dim = self.model.get_text_features(**inputs).shape[1]
        return self.dim
|
|
||||||
@@ -1,375 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
|
||||||
from camel.types import RoleType, TaskType
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageGenerator:
|
|
||||||
r"""System message generator for agents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_type (TaskType, optional): The task type.
|
|
||||||
(default: :obj:`TaskType.AI_SOCIETY`)
|
|
||||||
sys_prompts (Optional[Dict[RoleType, str]], optional): The prompts of
|
|
||||||
the system messages for each role type. (default: :obj:`None`)
|
|
||||||
sys_msg_meta_dict_keys (Optional[Set[str]], optional): The set of keys
|
|
||||||
of the meta dictionary used to fill the prompts.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
task_type: TaskType = TaskType.AI_SOCIETY,
|
|
||||||
sys_prompts: Optional[Dict[RoleType, str]] = None,
|
|
||||||
sys_msg_meta_dict_keys: Optional[Set[str]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.sys_prompts: Dict[RoleType, str]
|
|
||||||
|
|
||||||
if sys_prompts is not None:
|
|
||||||
self.sys_prompts = sys_prompts
|
|
||||||
self.sys_msg_meta_dict_keys = sys_msg_meta_dict_keys or set()
|
|
||||||
else:
|
|
||||||
assistant_prompt_template = (
|
|
||||||
PromptTemplateGenerator().get_system_prompt(
|
|
||||||
task_type,
|
|
||||||
RoleType.ASSISTANT,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user_prompt_template = PromptTemplateGenerator().get_system_prompt(
|
|
||||||
task_type,
|
|
||||||
RoleType.USER,
|
|
||||||
)
|
|
||||||
critic_prompt_template = (
|
|
||||||
PromptTemplateGenerator().get_system_prompt(
|
|
||||||
task_type,
|
|
||||||
RoleType.CRITIC,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
embodiment_prompt_template = (
|
|
||||||
PromptTemplateGenerator().get_system_prompt(
|
|
||||||
task_type,
|
|
||||||
RoleType.EMBODIMENT,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.sys_prompts = dict()
|
|
||||||
self.sys_prompts[RoleType.ASSISTANT] = assistant_prompt_template
|
|
||||||
self.sys_prompts[RoleType.USER] = user_prompt_template
|
|
||||||
self.sys_prompts[RoleType.CRITIC] = critic_prompt_template
|
|
||||||
self.sys_prompts[RoleType.EMBODIMENT] = embodiment_prompt_template
|
|
||||||
|
|
||||||
self.sys_msg_meta_dict_keys = (
|
|
||||||
assistant_prompt_template.key_words
|
|
||||||
| user_prompt_template.key_words
|
|
||||||
| critic_prompt_template.key_words
|
|
||||||
| embodiment_prompt_template.key_words
|
|
||||||
)
|
|
||||||
|
|
||||||
if RoleType.DEFAULT not in self.sys_prompts:
|
|
||||||
self.sys_prompts[RoleType.DEFAULT] = "You are a helpful assistant."
|
|
||||||
|
|
||||||
def validate_meta_dict_keys(self, meta_dict: Dict[str, str]) -> None:
|
|
||||||
r"""Validates the keys of the meta_dict.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
meta_dict (Dict[str, str]): The dictionary to validate.
|
|
||||||
"""
|
|
||||||
if not set(meta_dict.keys()).issubset(self.sys_msg_meta_dict_keys):
|
|
||||||
raise ValueError(
|
|
||||||
"The keys of the meta_dict should be in "
|
|
||||||
f"{self.sys_msg_meta_dict_keys}. "
|
|
||||||
f"Got {set(meta_dict.keys())} instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_dict(
|
|
||||||
self,
|
|
||||||
meta_dict: Dict[str, str],
|
|
||||||
role_tuple: Tuple[str, RoleType] = ("", RoleType.DEFAULT),
|
|
||||||
) -> BaseMessage:
|
|
||||||
r"""Generates a system message from a dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
meta_dict (Dict[str, str]): The dictionary containing the
|
|
||||||
information to generate the system message.
|
|
||||||
role_tuple (Tuple[str, RoleType], optional): The tuple containing
|
|
||||||
the role name and role type. (default: ("", RoleType.DEFAULT))
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessage: The generated system message.
|
|
||||||
"""
|
|
||||||
self.validate_meta_dict_keys(meta_dict)
|
|
||||||
role_name, role_type = role_tuple
|
|
||||||
sys_prompt = self.sys_prompts[role_type]
|
|
||||||
sys_prompt = sys_prompt.format(**meta_dict)
|
|
||||||
return BaseMessage(
|
|
||||||
role_name=role_name,
|
|
||||||
role_type=role_type,
|
|
||||||
meta_dict=meta_dict,
|
|
||||||
content=sys_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_dicts(
|
|
||||||
self,
|
|
||||||
meta_dicts: List[Dict[str, str]],
|
|
||||||
role_tuples: List[Tuple[str, RoleType]],
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
r"""Generates a list of system messages from a list of dictionaries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
meta_dicts (List[Dict[str, str]]): A list of dictionaries
|
|
||||||
containing the information to generate the system messages.
|
|
||||||
role_tuples (List[Tuple[str, RoleType]]): A list of tuples
|
|
||||||
containing the role name and role type for each system message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[BaseMessage]: A list of generated system messages.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the number of meta_dicts and role_tuples are
|
|
||||||
different.
|
|
||||||
"""
|
|
||||||
if len(meta_dicts) != len(role_tuples):
|
|
||||||
raise ValueError(
|
|
||||||
"The number of meta_dicts and role_types should be the same."
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
self.from_dict(meta_dict, role_tuple)
|
|
||||||
for meta_dict, role_tuple in zip(meta_dicts, role_tuples)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class RoleNameGenerator:
    r"""Role name generator for role-playing workers.

    Args:
        assistant_role_names_path (str, optional): The path to the file
            containing the assistant role names.
            (default: :obj:`"data/ai_society/assistant_roles.txt"`)
        user_role_names_path (str, optional): The path to the file
            containing the user role names.
            (default: :obj:`"data/ai_society/user_roles.txt"`)
        assistant_role_names (Optional[List[str]], optional): The list of
            assistant role names. (default: :obj:`None`)
        user_role_names (Optional[List[str]], optional): The list of user role
            names. (default: :obj:`None`)
    """

    def __init__(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
        assistant_role_names: Optional[List[str]] = None,
        user_role_names: Optional[List[str]] = None,
    ) -> None:
        def load_names(path: str) -> List[str]:
            # Each line looks like "<index> <role name>"; drop the index
            # token and keep the remainder.
            with open(path, "r") as role_file:
                lines = role_file.read().splitlines()
            return [" ".join(line.split(" ")[1:]) for line in lines]

        if assistant_role_names is None:
            self.assistant_role_names = load_names(assistant_role_names_path)
        else:
            self.assistant_role_names = assistant_role_names

        if user_role_names is None:
            self.user_role_names = load_names(user_role_names_path)
        else:
            self.user_role_names = user_role_names

    def from_role_files(self) -> Generator[Tuple, None, None]:
        r"""Generate role names from the file.

        Returns:
            Generator[Tuple, None, None]: A generator that yields tuples of
                assistant role names and user role names (their Cartesian
                product, assistant-major order).
        """
        for assistant_name in self.assistant_role_names:
            for user_name in self.user_role_names:
                yield (assistant_name, user_name)
|
|
||||||
class AISocietyTaskPromptGenerator:
    r"""Task prompt generator for AI society tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`10`)
    """

    def __init__(
        self,
        num_tasks: int = 10,
    ) -> None:
        # Template exposing `assistant_role`, `user_role` and `num_tasks`
        # placeholders.
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(
                TaskType.AI_SOCIETY
            )
        )
        self.num_tasks = num_tasks

    # TODO: Return role names for user and assistant with the generator.
    def from_role_files(
        self,
        assistant_role_names_path: str = "data/ai_society/assistant_roles.txt",
        user_role_names_path: str = "data/ai_society/user_roles.txt",
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from role files.

        Args:
            assistant_role_names_path (str, optional): The path to the file
                containing the assistant role names.
                (default: :obj:`"data/ai_society/assistant_roles.txt"`)
            user_role_names_path (str, optional): The path to the file
                containing the user role names.
                (default: :obj:`"data/ai_society/user_roles.txt"`)

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        # Delegate the prompt formatting to from_role_generator so both
        # entry points share one code path.
        role_pairs = RoleNameGenerator(
            assistant_role_names_path, user_role_names_path
        ).from_role_files()
        yield from self.from_role_generator(role_pairs)

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[Tuple[str, Tuple[str, str]], None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[Tuple[str, Tuple[str, str]], None, None]: A generator
                that yields tuples of task prompts and role names.
        """
        for assistant_role, user_role in role_generator:
            task_prompt = self.generate_tasks_prompt.format(
                assistant_role=assistant_role,
                user_role=user_role,
                num_tasks=self.num_tasks,
            )

            yield (task_prompt, (assistant_role, user_role))
|
|
||||||
class SingleTxtGenerator:
    r"""Single text generator for role-playing workers.

    Args:
        text_file_path (str): The path to the file containing the text data.
    """

    def __init__(
        self,
        text_file_path: str,
    ) -> None:
        with open(text_file_path, "r") as text_file:
            raw_lines: List[str] = text_file.read().splitlines()
        # Each line is "<index> <text>"; strip the leading index token.
        self.data_list = [
            " ".join(line.split(" ")[1:]) for line in raw_lines
        ]

    def from_role_files(self) -> Generator[str, None, None]:
        r"""Generate text from the file.

        Returns:
            Generator[str, None, None]: A generator that yields the text data.
        """
        yield from self.data_list
|
|
||||||
class CodeTaskPromptGenerator:
    r"""Code task prompt generator for code tasks.

    Args:
        num_tasks (int, optional): The number of tasks to generate.
            (default: :obj:`50`)
    """

    def __init__(
        self,
        num_tasks: int = 50,
    ) -> None:
        # Template exposing `language`, `domain` and `num_tasks` placeholders.
        self.generate_tasks_prompt = (
            PromptTemplateGenerator().get_generate_tasks_prompt(TaskType.CODE)
        )
        self.num_tasks = num_tasks

    def from_role_files(
        self,
        languages_path: str = "data/code/languages.txt",
        domains_path: str = "data/code/domains.txt",
    ) -> Generator[Tuple[TextPrompt, str, str], None, None]:
        r"""Generate tasks from role files.

        Args:
            languages_path (str, optional): The path to the file containing
                the language names. (default: :obj:`"data/code/languages.txt"`)
            domains_path (str, optional): The path to the file containing
                the domain names. (default: :obj:`"data/code/domains.txt"`)

        Returns:
            Generator[Tuple[TextPrompt, str, str], None, None]: A generator
                that yields tuples of task prompts, language names, and domain
                names.
        """
        for language in SingleTxtGenerator(languages_path).from_role_files():
            # Re-create the domain generator per language so each pass
            # iterates the full domain list.
            domain_names = SingleTxtGenerator(domains_path).from_role_files()
            for domain in domain_names:
                task_prompt = self.generate_tasks_prompt.format(
                    language=language, domain=domain, num_tasks=self.num_tasks
                )
                yield task_prompt, language, domain

    def from_role_generator(
        self, role_generator: Generator[Tuple, None, None]
    ) -> Generator[str, None, None]:
        r"""Generate tasks from a role generator.

        Args:
            role_generator (Generator[Tuple, None, None]): A generator that
                yields tuples of role names.

        Returns:
            Generator[str, None, None]: A generator that yields the task
                prompts.
        """
        raise NotImplementedError
@@ -1,138 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Any, Dict, Sequence
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from camel.messages import BaseMessage
|
|
||||||
from camel.responses import ChatAgentResponse
|
|
||||||
from camel.utils import print_text_animated
|
|
||||||
|
|
||||||
|
|
||||||
class Human:
    r"""A class representing a human user.

    Args:
        name (str): The name of the human user.
            (default: :obj:`"Kill Switch Engineer"`).
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)

    Attributes:
        name (str): The name of the human user.
        logger_color (Any): The color of the menu options displayed to the
            user.
        input_button (str): The text displayed for the input button.
        kill_button (str): The text displayed for the kill button.
        options_dict (Dict[str, str]): A dictionary containing the options
            displayed to the user.
    """

    def __init__(
        self,
        name: str = "Kill Switch Engineer",
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.name = name
        self.logger_color = logger_color
        self.input_button = f"Input by {self.name}."
        self.kill_button = "Stop!!!"
        self.options_dict: Dict[str, str] = {}

    def display_options(self, messages: Sequence[BaseMessage]) -> None:
        r"""Displays the options to the user.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.

        Returns:
            None
        """
        # Menu = one entry per proposal, plus manual-input and kill entries.
        menu_entries = [message.content for message in messages]
        menu_entries.append(self.input_button)
        menu_entries.append(self.kill_button)
        print_text_animated(
            self.logger_color + "\n> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for position, entry in enumerate(menu_entries, start=1):
            print_text_animated(
                self.logger_color
                + f"\x1b[3mOption {position}:\n{entry}\x1b[0m\n"
            )
            self.options_dict[str(position)] = entry

    def get_input(self) -> str:
        r"""Gets the input from the user.

        Returns:
            str: The user's input.
        """
        # Keep prompting until the user types a valid option number.
        while True:
            human_input = input(
                self.logger_color
                + f"Please enter your choice ([1-{len(self.options_dict)}]): "
            )
            print("\n")
            if human_input in self.options_dict:
                return human_input
            print_text_animated(
                self.logger_color + "\n> Invalid choice. Please try again.\n"
            )

    def parse_input(self, human_input: str) -> str:
        r"""Parses the user's input and returns a `BaseMessage` object.

        Args:
            human_input (str): The user's input.

        Returns:
            content: A `str` object representing the user's input.
        """
        selected = self.options_dict[human_input]
        if selected == self.input_button:
            # Free-form input chosen: ask for the actual message text.
            return input(self.logger_color + "Please enter your message: ")
        if selected == self.kill_button:
            # Kill switch chosen: terminate the process.
            exit(self.logger_color + f"Killed by {self.name}.")
        return selected

    def reduce_step(
        self, messages: Sequence[BaseMessage]
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by displaying options to the
        user, getting their input, and parsing their choice.

        Args:
            messages (Sequence[BaseMessage]): A list of BaseMessage objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object representing the
                user's choice.
        """
        # Clone role/meta information from the first proposal; only the
        # content differs in the reply.
        template_message = BaseMessage(
            role_name=messages[0].role_name,
            role_type=messages[0].role_type,
            meta_dict=messages[0].meta_dict,
            content="",
        )
        self.display_options(messages)
        chosen_key = self.get_input()
        content = self.parse_input(chosen_key)
        reply = template_message.create_new_instance(content)
        return ChatAgentResponse(msgs=[reply], terminated=False, info={})
@@ -1,29 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .base import BaseInterpreter
|
|
||||||
from .docker_interpreter import DockerInterpreter
|
|
||||||
from .internal_python_interpreter import InternalPythonInterpreter
|
|
||||||
from .interpreter_error import InterpreterError
|
|
||||||
from .ipython_interpreter import JupyterKernelInterpreter
|
|
||||||
from .subprocess_interpreter import SubprocessInterpreter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseInterpreter',
|
|
||||||
'InterpreterError',
|
|
||||||
'InternalPythonInterpreter',
|
|
||||||
'SubprocessInterpreter',
|
|
||||||
'DockerInterpreter',
|
|
||||||
'JupyterKernelInterpreter',
|
|
||||||
]
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
|
|
||||||
class BaseInterpreter(ABC):
    r"""An abstract base class for code interpreters."""

    @abstractmethod
    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code based on its type.

        Args:
            code (str): The code to be executed.
            code_type (str): The type of the code, which must be one of the
                types returned by `supported_code_types()`.

        Returns:
            str: The result of the code execution. If the execution fails,
                this should include sufficient information to diagnose and
                correct the issue.

        Raises:
            InterpreterError: If the code execution encounters errors that
                could be resolved by modifying or regenerating the code.
        """

    @abstractmethod
    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""

    @abstractmethod
    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
@@ -1,245 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import io
|
|
||||||
import shlex
|
|
||||||
import tarfile
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from camel.interpreters.base import BaseInterpreter
|
|
||||||
from camel.interpreters.interpreter_error import InterpreterError
|
|
||||||
from camel.logger import get_logger
|
|
||||||
from camel.utils import is_docker_running
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from docker.models.containers import Container
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DockerInterpreter(BaseInterpreter):
    r"""A class for executing code files or code strings in a docker container.

    This class handles the execution of code in different scripting languages
    (currently Python and Bash) within a docker container, capturing their
    stdout and stderr streams, and allowing user checking before executing code
    strings.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    # Shell command used to execute an uploaded file, per canonical code type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
    }

    # File extension associated with each canonical code type.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
    }

    # Aliases accepted from callers, normalized to a canonical code type.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

        # lazy initialization of container
        self._container: Optional[Container] = None

    def __del__(self) -> None:
        r"""Destructor for the DockerInterpreter class.

        This method ensures that the Docker container is removed when the
        interpreter is deleted.
        """
        if self._container is not None:
            self._container.remove(force=True)

    def _initialize_if_needed(self) -> None:
        r"""Starts the backing container on first use only."""
        if self._container is not None:
            return

        if not is_docker_running():
            raise InterpreterError(
                "Docker daemon is not running. Please install/start docker "
                "and try again."
            )

        import docker

        client = docker.from_env()
        # `tail -f /dev/null` keeps the container alive so we can exec into
        # it repeatedly.
        self._container = client.containers.run(
            "python:3.10",
            detach=True,
            name=f"camel-interpreter-{uuid.uuid4()}",
            command="tail -f /dev/null",
        )

    def _create_file_in_container(self, content: str) -> Path:
        r"""Writes ``content`` to a uniquely named file under ``/tmp`` inside
        the container and returns that file's path."""
        # get a random name for the file
        filename = str(uuid.uuid4())
        # Fix: encode once and size the tar entry by BYTE length; using
        # len(content) (character count) corrupts the archive for any
        # non-ASCII content.
        payload = content.encode('utf-8')
        # create a tar in memory
        tar_stream = io.BytesIO()
        with tarfile.open(fileobj=tar_stream, mode='w') as tar:
            tarinfo = tarfile.TarInfo(name=filename)
            tarinfo.size = len(payload)
            tar.addfile(tarinfo, io.BytesIO(payload))
        tar_stream.seek(0)

        # copy the tar into the container
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        self._container.put_archive("/tmp", tar_stream)
        # Fix: return the path of the file actually written (the random
        # name), not a hard-coded placeholder.
        return Path(f"/tmp/{filename}")

    def _run_file_in_container(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Executes ``file`` inside the container and returns the combined
        stdout/stderr text."""
        code_type = self._check_code_type(code_type)
        commands = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=file.as_posix()
            )
        )
        if self._container is None:
            raise InterpreterError(
                "Container is not initialized. Try running the code again."
            )
        # demux=True yields separate (stdout, stderr) byte streams; either
        # may be None when the stream produced no output.
        stdout, stderr = self._container.exec_run(
            commands,
            demux=True,
        ).output

        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout.decode() + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr.decode() + Fore.RESET)
            print("==================")
        exec_result = f"{stdout.decode()}" if stdout else ""
        exec_result += f"(stderr: {stderr.decode()})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Executes the given code in the container attached to the
        interpreter, and captures the stdout and stderr streams.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code, or the
                code type is unsupported, or there is an error in the docker
                API/container
        """
        import docker.errors

        code_type = self._check_code_type(code_type)

        # Print code for security checking
        if self.require_confirm:
            # Fix: both fragments must be f-strings; previously the second
            # fragment printed the literal text "{code}".
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice not in ["no", "n"]:
                    continue
                raise InterpreterError(
                    "Execution halted: User opted not to run the code. "
                    "This choice stops the current operation and any "
                    "further code execution."
                )

        self._initialize_if_needed()

        try:
            temp_file_path = self._create_file_in_container(code)
            result = self._run_file_in_container(temp_file_path, code_type)
        except docker.errors.APIError as e:
            raise InterpreterError(
                f"Execution halted due to docker API error: {e.explanation}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        except docker.errors.DockerException as e:
            # Fix: "exceptoin" typo corrected in the error message.
            raise InterpreterError(
                f"Execution halted due to docker exception: {e}. "
                "This choice stops the current operation and any "
                "further code execution."
            ) from e
        return result

    def _check_code_type(self, code_type: str) -> str:
        r"""Normalizes ``code_type`` aliases; raises for unsupported types."""
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
        # Fix: error message previously named SubprocessInterpreter
        # (copy-paste); this is the DockerInterpreter.
        raise RuntimeError(
            "DockerInterpreter doesn't support `action_space`."
        )
@@ -1,516 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import ast
|
|
||||||
import difflib
|
|
||||||
import importlib
|
|
||||||
import typing
|
|
||||||
from typing import Any, ClassVar, Dict, List, Optional
|
|
||||||
|
|
||||||
from camel.interpreters.base import BaseInterpreter
|
|
||||||
from camel.interpreters.interpreter_error import InterpreterError
|
|
||||||
|
|
||||||
|
|
||||||
class InternalPythonInterpreter(BaseInterpreter):
    r"""A customized python interpreter to control the execution of
    LLM-generated codes. The interpreter makes sure the code can only execute
    functions given in action space and import white list. It also supports
    fuzzy variable matching to retrieve uncertain input variable name.

    .. highlight:: none

    This class is adapted from the hugging face implementation
    `python_interpreter.py <https://github.com/huggingface/transformers/blob/8f
    093fb799246f7dd9104ff44728da0c53a9f67a/src/transformers/tools/python_interp
    reter.py>`_. The original license applies::

        Copyright 2023 The HuggingFace Inc. team. All rights reserved.

        Licensed under the Apache License, Version 2.0 (the "License");
        you may not use this file except in compliance with the License.
        You may obtain a copy of the License at

            http://www.apache.org/licenses/LICENSE-2.0

        Unless required by applicable law or agreed to in writing, software
        distributed under the License is distributed on an "AS IS" BASIS,
        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
        implied. See the License for the specific language governing
        permissions and limitations under the License.

    We have modified the original code to suit our requirements. We have
    encapsulated the original functions within a class and saved the
    interpreter state after execution. We have added support for "import"
    statements, "for" statements, and several binary and unary operators. We
    have added import white list to keep `import` statement safe. Additionally,
    we have modified the variable matching logic and introduced the
    :obj:`fuzz_state` for fuzzy matching.

    Modifications copyright (C) 2023 CAMEL-AI.org

    Args:
        action_space (Dict[str, Any], optional): A dictionary that maps action
            names to their corresponding functions or objects. The interpreter
            can only execute functions that are either directly listed in this
            dictionary or are member functions of objects listed in this
            dictionary. The concept of :obj:`action_space` is derived from
            EmbodiedAgent, representing the actions that an agent is capable of
            performing. If `None`, set to empty dict. (default: :obj:`None`)
        import_white_list (List[str], optional): A list that stores
            the Python modules or functions that can be imported in the code.
            All submodules and functions of the modules listed in this list are
            importable. Any other import statements will be rejected. The
            module and its submodule or function name are separated by a period
            (:obj:`.`). (default: :obj:`None`)
        unsafe_mode (bool, optional): If `True`, the interpreter runs the code
            by `eval()` without any security check. (default: :obj:`False`)
        raise_error (bool, optional): Raise error if the interpreter fails.
            (default: :obj:`False`)
    """

    _CODE_TYPES: ClassVar[List[str]] = ["python", "py", "python3", "python2"]

    def __init__(
        self,
        action_space: Optional[Dict[str, Any]] = None,
        import_white_list: Optional[List[str]] = None,
        unsafe_mode: bool = False,
        raise_error: bool = False,
    ) -> None:
        self.action_space = action_space or dict()
        # `state` holds every name visible to executed code; it starts as a
        # copy of the action space and grows with assignments/imports.
        self.state = self.action_space.copy()
        # `fuzz_state` holds externally supplied variables looked up via
        # fuzzy (close-match) name resolution.
        self.fuzz_state: Dict[str, Any] = dict()
        self.import_white_list = import_white_list or list()
        self.raise_error = raise_error
        self.unsafe_mode = unsafe_mode

    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code with specified code type in the
        interpreter.

        This method takes a string of code and its type, checks if the code
        type is supported, and then executes the code. If `unsafe_mode` is
        set to `False`, the code is executed in a controlled environment using
        the `execute` method. If `unsafe_mode` is `True`, the code is executed
        using `eval()` with the action space as the global context. An
        `InterpreterError` is raised if the code type is unsupported or if any
        runtime error occurs during execution.

        Args:
            code (str): The python code to be executed.
            code_type (str): The type of the code, which should be one of the
                supported code types (`python`, `py`, `python3`, `python2`).

        Returns:
            str: The string representation of the output of the executed code.

        Raises:
            InterpreterError: If the `code_type` is not supported or if any
                runtime error occurs during the execution of the code.
        """
        if code_type not in self._CODE_TYPES:
            raise InterpreterError(
                f"Unsupported code type {code_type}. "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_TYPES)}."
            )
        if not self.unsafe_mode:
            return str(self.execute(code))
        else:
            # SECURITY: unsafe mode evaluates arbitrary code with no
            # sandboxing; only use with trusted input.
            return str(eval(code, self.action_space))

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter."""
        self.action_space.update(action_space)

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return self._CODE_TYPES

    def execute(
        self,
        code: str,
        state: Optional[Dict[str, Any]] = None,
        fuzz_state: Optional[Dict[str, Any]] = None,
        keep_state: bool = True,
    ) -> Any:
        r"""Execute the input python codes in a security environment.

        Args:
            code (str): Generated python code to be executed.
            state (Optional[Dict[str, Any]], optional): External variables that
                may be used in the generated code. (default: :obj:`None`)
            fuzz_state (Optional[Dict[str, Any]], optional): External variables
                that do not have certain variable names. The interpreter will
                use fuzzy matching to access these variables. For example, if
                :obj:`fuzz_state` has a variable :obj:`image`, the generated
                code can use :obj:`input_image` to access it. (default:
                :obj:`None`)
            keep_state (bool, optional): If :obj:`True`, :obj:`state` and
                :obj:`fuzz_state` will be kept for later execution. Otherwise,
                they will be cleared. (default: :obj:`True`)

        Returns:
            Any: The value of the last statement (excluding "import") in the
                code. For this interpreter, the value of an expression is its
                value, the value of an "assign" statement is the assigned
                value, and the value of an "if" and "for" block statement is
                the value of the last statement in the block.
        """
        if state is not None:
            self.state.update(state)
        if fuzz_state is not None:
            self.fuzz_state.update(fuzz_state)

        try:
            expression = ast.parse(code)
        except SyntaxError as e:
            if self.raise_error:
                # Chain the original SyntaxError for easier debugging.
                raise InterpreterError(f"Syntax error in code: {e}") from e
            else:
                import traceback

                return traceback.format_exc()

        result = None
        for idx, node in enumerate(expression.body):
            try:
                line_result = self._execute_ast(node)
            except InterpreterError as e:
                if not keep_state:
                    self.clear_state()
                msg = (
                    f"Evaluation of the code stopped at node {idx}. "
                    f"See:\n{e}"
                )
                # More information can be provided by `ast.unparse()`,
                # which is new in python 3.9.
                if self.raise_error:
                    raise InterpreterError(msg) from e
                else:
                    import traceback

                    return traceback.format_exc()
            if line_result is not None:
                result = line_result

        if not keep_state:
            self.clear_state()

        return result

    def clear_state(self) -> None:
        r"""Initialize :obj:`state` and :obj:`fuzz_state`."""
        self.state = self.action_space.copy()
        self.fuzz_state = {}

    # ast.Index is deprecated after python 3.9, which cannot pass type check,
    # but is still necessary for older versions.
    @typing.no_type_check
    def _execute_ast(self, expression: ast.AST) -> Any:
        r"""Dispatch a single AST node to its dedicated evaluator.

        Raises:
            InterpreterError: If the node type has no evaluator.
        """
        if isinstance(expression, ast.Assign):
            # Assignment -> evaluate the assignment which should
            # update the state. We return the variable assigned as it may
            # be used to determine the final result.
            return self._execute_assign(expression)
        elif isinstance(expression, ast.Attribute):
            value = self._execute_ast(expression.value)
            return getattr(value, expression.attr)
        elif isinstance(expression, ast.BinOp):
            # Binary Operator -> return the result value
            return self._execute_binop(expression)
        elif isinstance(expression, ast.Call):
            # Function call -> return the value of the function call
            return self._execute_call(expression)
        elif isinstance(expression, ast.Compare):
            # Compare -> return True or False
            return self._execute_condition(expression)
        elif isinstance(expression, ast.Constant):
            # Constant -> just return the value
            return expression.value
        elif isinstance(expression, ast.Dict):
            # Dict -> evaluate all keys and values
            result: Dict = {}
            for k, v in zip(expression.keys, expression.values):
                if k is not None:
                    result[self._execute_ast(k)] = self._execute_ast(v)
                else:
                    # A `None` key is a `**mapping` expansion in a dict
                    # literal.
                    result.update(self._execute_ast(v))
            return result
        elif isinstance(expression, ast.Expr):
            # Expression -> evaluate the content
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.For):
            return self._execute_for(expression)
        elif isinstance(expression, ast.FormattedValue):
            # Formatted value (part of f-string) -> evaluate the content
            # and return
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.If):
            # If -> execute the right branch
            return self._execute_if(expression)
        elif isinstance(expression, ast.Import):
            # Import -> add imported names in self.state and return None.
            self._execute_import(expression)
            return None
        elif isinstance(expression, ast.ImportFrom):
            self._execute_import_from(expression)
            return None
        elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
            # cannot pass type check
            return self._execute_ast(expression.value)
        elif isinstance(expression, ast.JoinedStr):
            return "".join(
                [str(self._execute_ast(v)) for v in expression.values]
            )
        elif isinstance(expression, ast.List):
            # List -> evaluate all elements
            return [self._execute_ast(elt) for elt in expression.elts]
        elif isinstance(expression, ast.Name):
            # Name -> pick up the value in the state
            return self._execute_name(expression)
        elif isinstance(expression, ast.Subscript):
            # Subscript -> return the value of the indexing
            return self._execute_subscript(expression)
        elif isinstance(expression, ast.Tuple):
            return tuple([self._execute_ast(elt) for elt in expression.elts])
        elif isinstance(expression, ast.UnaryOp):
            # Unary Operator -> return the result value
            return self._execute_unaryop(expression)
        else:
            # For now we refuse anything else. Let's add things as we need
            # them.
            raise InterpreterError(
                f"{expression.__class__.__name__} is not supported."
            )

    def _execute_assign(self, assign: ast.Assign) -> Any:
        r"""Evaluate the RHS once, then bind it to every target."""
        targets = assign.targets
        result = self._execute_ast(assign.value)

        for target in targets:
            self._assign(target, result)
        return result

    def _assign(self, target: ast.expr, value: Any):
        r"""Bind `value` to a name or tuple-of-names target in `state`.

        Raises:
            InterpreterError: On non-tuple value for a tuple target, on
                length mismatch, or on an unsupported target type.
        """
        if isinstance(target, ast.Name):
            self.state[target.id] = value
        elif isinstance(target, ast.Tuple):
            if not isinstance(value, tuple):
                # Fix: the original message was missing a space and read
                # e.g. "but gotlist instead."
                raise InterpreterError(
                    f"Expected type tuple, but got "
                    f"{value.__class__.__name__} instead."
                )
            if len(target.elts) != len(value):
                raise InterpreterError(
                    f"Expected {len(target.elts)} values but got"
                    f" {len(value)}."
                )
            for t, v in zip(target.elts, value):
                self.state[self._execute_ast(t)] = v
        else:
            raise InterpreterError(
                f"Unsupported variable type. Expected "
                f"ast.Name or ast.Tuple, got "
                f"{target.__class__.__name__} instead."
            )

    def _execute_call(self, call: ast.Call) -> Any:
        r"""Evaluate callee, positional args and keyword args, then call."""
        callable_func = self._execute_ast(call.func)

        # Todo deal with args
        args = [self._execute_ast(arg) for arg in call.args]
        kwargs = {
            keyword.arg: self._execute_ast(keyword.value)
            for keyword in call.keywords
        }
        return callable_func(*args, **kwargs)

    def _execute_subscript(self, subscript: ast.Subscript):
        r"""Evaluate a read-only subscript, with fuzzy dict-key fallback.

        Raises:
            InterpreterError: For store/delete contexts or unresolvable
                indices.
        """
        index = self._execute_ast(subscript.slice)
        value = self._execute_ast(subscript.value)
        if not isinstance(subscript.ctx, ast.Load):
            raise InterpreterError(
                f"{subscript.ctx.__class__.__name__} is not supported for "
                "subscript."
            )
        if isinstance(value, (list, tuple)):
            return value[int(index)]
        if index in value:
            return value[index]
        if isinstance(index, str) and isinstance(value, dict):
            # Fuzzy fallback: tolerate slightly misspelled string keys.
            close_matches = difflib.get_close_matches(
                index,
                [key for key in list(value.keys()) if isinstance(key, str)],
            )
            if len(close_matches) > 0:
                return value[close_matches[0]]

        raise InterpreterError(f"Could not index {value} with '{index}'.")

    def _execute_name(self, name: ast.Name):
        r"""Resolve a name: store-context yields the identifier itself,
        load-context looks it up in the interpreter state."""
        if isinstance(name.ctx, ast.Store):
            return name.id
        elif isinstance(name.ctx, ast.Load):
            return self._get_value_from_state(name.id)
        else:
            raise InterpreterError(f"{name.ctx} is not supported.")

    def _execute_condition(self, condition: ast.Compare):
        r"""Evaluate a single-operator comparison.

        Raises:
            InterpreterError: For chained comparisons (``a < b < c``) or an
                unsupported comparison operator.
        """
        if len(condition.ops) > 1:
            raise InterpreterError(
                "Cannot evaluate conditions with multiple operators"
            )

        left = self._execute_ast(condition.left)
        comparator = condition.ops[0]
        right = self._execute_ast(condition.comparators[0])

        if isinstance(comparator, ast.Eq):
            return left == right
        elif isinstance(comparator, ast.NotEq):
            return left != right
        elif isinstance(comparator, ast.Lt):
            return left < right
        elif isinstance(comparator, ast.LtE):
            return left <= right
        elif isinstance(comparator, ast.Gt):
            return left > right
        elif isinstance(comparator, ast.GtE):
            return left >= right
        elif isinstance(comparator, ast.Is):
            return left is right
        elif isinstance(comparator, ast.IsNot):
            return left is not right
        elif isinstance(comparator, ast.In):
            return left in right
        elif isinstance(comparator, ast.NotIn):
            return left not in right
        else:
            raise InterpreterError(f"Unsupported operator: {comparator}")

    def _execute_if(self, if_statement: ast.If):
        r"""Execute an ``if`` whose test is a comparison; returns the value
        of the last non-None statement in the taken branch."""
        result = None
        if not isinstance(if_statement.test, ast.Compare):
            # Fix: corrected "Campare ... get" typo in the error message.
            raise InterpreterError(
                "Only Compare expr supported in if statement, got"
                f" {if_statement.test.__class__.__name__}"
            )
        if self._execute_condition(if_statement.test):
            for line in if_statement.body:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result
        else:
            for line in if_statement.orelse:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result
        return result

    def _execute_for(self, for_statement: ast.For):
        r"""Execute a ``for`` loop; `break`/`continue` are not supported.
        Returns the value of the last non-None statement in the body."""
        result = None
        for value in self._execute_ast(for_statement.iter):
            self._assign(for_statement.target, value)
            for line in for_statement.body:
                line_result = self._execute_ast(line)
                if line_result is not None:
                    result = line_result

        return result

    def _execute_import(self, import_module: ast.Import) -> None:
        r"""Handle ``import a.b [as c]`` after white-list validation."""
        for module in import_module.names:
            self._validate_import(module.name)
            alias = module.asname or module.name
            self.state[alias] = importlib.import_module(module.name)

    def _execute_import_from(self, import_from: ast.ImportFrom):
        r"""Handle ``from a import b [as c]`` after white-list validation."""
        if import_from.module is None:
            raise InterpreterError("\"from . import\" is not supported.")
        for import_name in import_from.names:
            full_name = import_from.module + f".{import_name.name}"
            self._validate_import(full_name)
            imported_module = importlib.import_module(import_from.module)
            alias = import_name.asname or import_name.name
            self.state[alias] = getattr(imported_module, import_name.name)

    def _validate_import(self, full_name: str):
        r"""Check that some dotted prefix of `full_name` is white-listed.

        Raises:
            InterpreterError: If no prefix of the dotted path appears in
                :obj:`import_white_list`.
        """
        tmp_name = ""
        for name in full_name.split("."):
            tmp_name += name if tmp_name == "" else f".{name}"
            if tmp_name in self.import_white_list:
                # A white-listed prefix permits all of its submodules.
                return

        # Fix: reworded the previously ungrammatical error message.
        raise InterpreterError(
            f"It is not permitted to import modules "
            f"outside the module white list (attempted "
            f"to import {full_name})."
        )

    def _execute_binop(self, binop: ast.BinOp):
        r"""Evaluate a binary arithmetic/bitwise operation.

        Raises:
            InterpreterError: For unsupported operators (e.g. BitAnd/BitOr).
        """
        left = self._execute_ast(binop.left)
        operator = binop.op
        right = self._execute_ast(binop.right)

        if isinstance(operator, ast.Add):
            return left + right
        elif isinstance(operator, ast.Sub):
            return left - right
        elif isinstance(operator, ast.Mult):
            return left * right
        elif isinstance(operator, ast.Div):
            return left / right
        elif isinstance(operator, ast.FloorDiv):
            return left // right
        elif isinstance(operator, ast.Mod):
            return left % right
        elif isinstance(operator, ast.Pow):
            return left**right
        elif isinstance(operator, ast.LShift):
            return left << right
        elif isinstance(operator, ast.RShift):
            return left >> right
        elif isinstance(operator, ast.MatMult):
            return left @ right
        else:
            raise InterpreterError(f"Operator not supported: {operator}")

    def _execute_unaryop(self, unaryop: ast.UnaryOp):
        r"""Evaluate a unary operation (+x, -x, not x)."""
        operand = self._execute_ast(unaryop.operand)
        operator = unaryop.op

        if isinstance(operator, ast.UAdd):
            return +operand
        elif isinstance(operator, ast.USub):
            return -operand
        elif isinstance(operator, ast.Not):
            return not operand
        else:
            raise InterpreterError(f"Operator not supported: {operator}")

    def _get_value_from_state(self, key: str) -> Any:
        r"""Look up `key` in `state`, falling back to fuzzy matching
        against `fuzz_state` names.

        Raises:
            InterpreterError: If the name cannot be resolved either way.
        """
        if key in self.state:
            return self.state[key]
        else:
            close_matches = difflib.get_close_matches(
                key, list(self.fuzz_state.keys()), n=1
            )
            if close_matches:
                return self.fuzz_state[close_matches[0]]
            else:
                raise InterpreterError(f"The variable `{key}` is not defined.")
@@ -1,19 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
# TODO: Do we need a file to store this error class?
|
|
||||||
class InterpreterError(Exception):
    r"""Raised when code execution fails in a way that may be fixed by
    regenerating the code (e.g. syntax errors or unsupported constructs)."""
@@ -1,168 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import queue
|
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from camel.interpreters.base import BaseInterpreter
|
|
||||||
from camel.interpreters.interpreter_error import InterpreterError
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from jupyter_client import BlockingKernelClient, KernelManager
|
|
||||||
|
|
||||||
TIMEOUT = 30
|
|
||||||
|
|
||||||
|
|
||||||
class JupyterKernelInterpreter(BaseInterpreter):
    r"""A class for executing code strings in a Jupyter Kernel.

    Args:
        require_confirm (bool, optional): If `True`, prompt user before
            running code strings for security. Defaults to `True`.
        print_stdout (bool, optional): If `True`, print the standard
            output of the executed code. Defaults to `False`.
        print_stderr (bool, optional): If `True`, print the standard error
            of the executed code. Defaults to `True`.
    """

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

        # Quoted annotations: `KernelManager`/`BlockingKernelClient` are only
        # imported under TYPE_CHECKING, and annotations on attribute targets
        # are evaluated at runtime, so unquoted names would raise NameError.
        self.kernel_manager: Optional["KernelManager"] = None
        self.client: Optional["BlockingKernelClient"] = None

    def __del__(self) -> None:
        r"""Clean up the kernel and client."""

        if self.kernel_manager:
            self.kernel_manager.shutdown_kernel()
        if self.client:
            self.client.stop_channels()

    def _initialize_if_needed(self) -> None:
        r"""Initialize the kernel manager and client if they are not already
        initialized.
        """

        if self.kernel_manager is not None:
            return

        # Imported lazily so that jupyter_client is only required when the
        # interpreter is actually used.
        from jupyter_client.manager import start_new_kernel

        self.kernel_manager, self.client = start_new_kernel()

    @staticmethod
    def _clean_ipython_output(output: str) -> str:
        r"""Remove ANSI escape sequences from the output."""

        ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')
        return ansi_escape.sub('', output)

    def _execute(self, code: str, timeout: float) -> str:
        r"""Execute the code in the Jupyter kernel and return the result.

        Args:
            code (str): The code to execute in the kernel.
            timeout (float): Seconds to wait for each IOPub message before
                giving up.

        Returns:
            str: The concatenated, ANSI-stripped output of the execution.

        Raises:
            InterpreterError: If the Jupyter client has not been initialized.
        """

        if not self.kernel_manager or not self.client:
            raise InterpreterError("Jupyter client is not initialized.")

        self.client.execute(code)
        outputs = []
        while True:
            try:
                msg = self.client.get_iopub_msg(timeout=timeout)
                msg_content = msg["content"]
                msg_type = msg.get("msg_type", None)

                # An "idle" execution state marks the end of the execution.
                if msg_content.get("execution_state", None) == "idle":
                    break

                if msg_type == "error":
                    # Fix: removed leftover debug `print()` calls that leaked
                    # raw message dicts to stdout.
                    traceback = "\n".join(msg_content["traceback"])
                    outputs.append(traceback)
                elif msg_type == "stream":
                    outputs.append(msg_content["text"])
                elif msg_type in ["execute_result", "display_data"]:
                    outputs.append(msg_content["data"]["text/plain"])
                    if "image/png" in msg_content["data"]:
                        outputs.append(
                            f"\n![image](data:image/png;base64,"
                            f"{msg_content['data']['image/png']})\n"
                        )
            except queue.Empty:
                outputs.append("Time out")
                break
            except Exception as e:
                outputs.append(f"Exception occurred: {e!s}")
                break

        exec_result = "\n".join(outputs)
        return self._clean_ipython_output(exec_result)

    def run(self, code: str, code_type: str) -> str:
        r"""Executes the given code in the Jupyter kernel.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured result of the
                executed code.

        Raises:
            InterpreterError: If there is an error when doing code execution.
        """
        self._initialize_if_needed()

        if code_type == "bash":
            code = f"%%bash\n({code})"
        try:
            result = self._execute(code, timeout=TIMEOUT)
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise InterpreterError(f"Execution failed: {e!s}") from e

        return result

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter.

        Returns:
            List[str]: Supported code types.
        """
        return ["python", "bash"]

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates the action space for the interpreter.

        Args:
            action_space (Dict[str, Any]): A dictionary representing the
                new or updated action space.

        Raises:
            RuntimeError: Always raised because `JupyterKernelInterpreter`
                does not support updating the action space.
        """
        # Fix: the message previously named the wrong class
        # ("SubprocessInterpreter"), a copy-paste error.
        raise RuntimeError(
            "JupyterKernelInterpreter doesn't support `action_space`."
        )
@@ -1,212 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# You may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import shlex
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, ClassVar, Dict, List
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from camel.interpreters.base import BaseInterpreter
|
|
||||||
from camel.interpreters.interpreter_error import InterpreterError
|
|
||||||
from camel.logger import get_logger
|
|
||||||
import os
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SubprocessInterpreter(BaseInterpreter):
    r"""SubprocessInterpreter is a class for executing code files or code
    strings in a subprocess.

    This class handles the execution of code in different scripting languages
    (currently Python, Bash, and Node.js) within a subprocess, capturing their
    stdout and stderr streams, and allowing user checking before executing code
    strings.

    Args:
        require_confirm (bool, optional): If True, prompt user before running
            code strings for security. (default: :obj:`True`)
        print_stdout (bool, optional): If True, print the standard output of
            the executed code. (default: :obj:`False`)
        print_stderr (bool, optional): If True, print the standard error of the
            executed code. (default: :obj:`True`)
    """

    # Shell command template used to execute a file of each canonical type.
    _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python {file_name}",
        "bash": "bash {file_name}",
        "node": "node {file_name}",
    }

    # File extension used when materializing a code string to disk.
    _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "py",
        "bash": "sh",
        "node": "js",
    }

    # Aliases accepted from callers, normalized to a canonical code type.
    _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
        "node": "node",
        "javascript": "node",
        "js": "node",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr

    def run_file(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Executes a code file in a subprocess and captures its output.

        Args:
            file (Path): The path object of the file to run.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            RuntimeError: If the provided file path does not point to a file.
            InterpreterError: If the code type provided is not supported.
        """
        if not file.is_file():
            raise RuntimeError(f"{file} is not a file.")
        code_type = self._check_code_type(code_type)
        cmd = shlex.split(
            self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
                file_name=str(file)
            )
        )

        proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        stdout, stderr = proc.communicate()
        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr + Fore.RESET)
            print("==================")
        exec_result = f"{stdout}"
        exec_result += f"(stderr: {stderr})" if stderr else ""
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Generates a temporary file with the given code, executes it, and
        deletes the file afterward.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash', 'node').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code or if the
                code type is unsupported.
        """
        code_type = self._check_code_type(code_type)

        if self.require_confirm:
            # BUGFIX: the second half of this message was previously a plain
            # string literal, so the literal text "{code}" was logged instead
            # of the actual code.
            logger.info(
                f"The following {code_type} code will run on your "
                f"computer: {code}"
            )
            while True:
                choice = input("Running code? [Y/n]:").lower()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice in ["no", "n"]:
                    raise InterpreterError(
                        "Execution halted: User opted not to run the code. "
                        "This choice stops the current operation and any "
                        "further code execution."
                    )

        temp_file_path = self._create_temp_file(
            code=code, extension=self._CODE_EXTENSION_MAPPING[code_type]
        )

        # BUGFIX: delete the temporary file even when execution raises, so a
        # failing run does not leak files in the temp directory.
        try:
            result = self.run_file(temp_file_path, code_type)
        finally:
            temp_file_path.unlink()
        return result

    def _create_temp_file(self, code: str, extension: str) -> Path:
        r"""Writes ``code`` to a fresh temporary file with the given
        extension and returns its path. The caller is responsible for
        deleting the file (see :meth:`run`)."""
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=f".{extension}"
        ) as f:
            f.write(code)
            name = f.name
        return Path(name)

    def _check_code_type(self, code_type: str) -> str:
        r"""Normalizes a user-supplied code type alias to its canonical form.

        Raises:
            InterpreterError: If the code type is not supported.
        """
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> List[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: Dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
        raise RuntimeError(
            "SubprocessInterpreter doesn't support " "`action_space`."
        )
@@ -1,29 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .apify_reader import Apify
|
|
||||||
from .base_io import File
|
|
||||||
from .chunkr_reader import ChunkrReader
|
|
||||||
from .firecrawl_reader import Firecrawl
|
|
||||||
from .jina_url_reader import JinaURLReader
|
|
||||||
from .unstructured_io import UnstructuredIO
|
|
||||||
|
|
||||||
# Public API of the loaders package, in documentation order.
__all__ = [
    'File',
    'UnstructuredIO',
    'JinaURLReader',
    'Firecrawl',
    'Apify',
    'ChunkrReader',
]
||||||
@@ -1,223 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from apify_client.clients import DatasetClient
|
|
||||||
|
|
||||||
from camel.utils import api_keys_required
|
|
||||||
|
|
||||||
|
|
||||||
class Apify:
    r"""Apify is a platform that allows you to automate any web workflow.

    Args:
        api_key (Optional[str]): API key for authenticating with the Apify API.
    """

    @api_keys_required("APIFY_API_KEY")
    def __init__(
        self,
        api_key: Optional[str] = None,
    ) -> None:
        # Imported lazily so `apify_client` is only required when this
        # class is actually instantiated.
        from apify_client import ApifyClient

        self._api_key = api_key or os.environ.get("APIFY_API_KEY")
        self.client = ApifyClient(token=self._api_key)

    def run_actor(
        self,
        actor_id: str,
        run_input: Optional[dict] = None,
        content_type: Optional[str] = None,
        build: Optional[str] = None,
        max_items: Optional[int] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
        webhooks: Optional[list] = None,
        wait_secs: Optional[int] = None,
    ) -> Optional[dict]:
        r"""Run an actor on the Apify platform.

        Args:
            actor_id (str): The ID of the actor to run.
            run_input (Optional[dict]): The input data for the actor. Defaults
                to `None`.
            content_type (str, optional): The content type of the input.
            build (str, optional): Specifies the Actor build to run. It can be
                either a build tag or build number. By default, the run uses
                the build specified in the default run configuration for the
                Actor (typically latest).
            max_items (int, optional): Maximum number of results that will be
                returned by this run. If the Actor is charged per result, you
                will not be charged for more results than the given limit.
            memory_mbytes (int, optional): Memory limit for the run, in
                megabytes. By default, the run uses a memory limit specified in
                the default run configuration for the Actor.
            timeout_secs (int, optional): Optional timeout for the run, in
                seconds. By default, the run uses timeout specified in the
                default run configuration for the Actor.
            webhooks (list, optional): Optional webhooks
                (https://docs.apify.com/webhooks) associated with the Actor
                run, which can be used to receive a notification, e.g. when the
                Actor finished or failed. If you already have a webhook set up
                for the Actor, you do not have to add it again here.
            wait_secs (int, optional): The maximum number of seconds the server
                waits for finish. If not provided, waits indefinitely.

        Returns:
            Optional[dict]: The output data from the actor if successful.
            # please use the 'defaultDatasetId' to get the dataset

        Raises:
            RuntimeError: If the actor fails to run.
        """
        try:
            return self.client.actor(actor_id).call(
                run_input=run_input,
                content_type=content_type,
                build=build,
                max_items=max_items,
                memory_mbytes=memory_mbytes,
                timeout_secs=timeout_secs,
                webhooks=webhooks,
                wait_secs=wait_secs,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to run actor {actor_id}: {e}") from e

    def get_dataset_client(
        self,
        dataset_id: str,
    ) -> "DatasetClient":
        r"""Get a dataset client from the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get the client for.

        Returns:
            DatasetClient: The dataset client.

        Raises:
            RuntimeError: If the dataset client fails to be retrieved.
        """
        try:
            return self.client.dataset(dataset_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def get_dataset(
        self,
        dataset_id: str,
    ) -> Optional[dict]:
        r"""Get a dataset from the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get.

        Returns:
            dict: The dataset.

        Raises:
            RuntimeError: If the dataset fails to be retrieved.
        """
        try:
            return self.get_dataset_client(dataset_id).get()
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset {dataset_id}: {e}"
            ) from e

    def update_dataset(
        self,
        dataset_id: str,
        name: str,
    ) -> dict:
        r"""Update a dataset on the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to update.
            name (str): The new name for the dataset.

        Returns:
            dict: The updated dataset.

        Raises:
            RuntimeError: If the dataset fails to be updated.
        """
        try:
            return self.get_dataset_client(dataset_id).update(name=name)
        except Exception as e:
            raise RuntimeError(
                f"Failed to update dataset {dataset_id}: {e}"
            ) from e

    def get_dataset_items(
        self,
        dataset_id: str,
    ) -> List:
        r"""Get items from a dataset on the Apify platform.

        Args:
            dataset_id (str): The ID of the dataset to get items from.

        Returns:
            list: The items in the dataset.

        Raises:
            RuntimeError: If the items fail to be retrieved.
        """
        try:
            # Return directly; the previous temporary variable added nothing.
            return self.get_dataset_client(dataset_id).list_items().items
        except Exception as e:
            raise RuntimeError(
                f"Failed to get dataset items {dataset_id}: {e}"
            ) from e

    def get_datasets(
        self,
        unnamed: Optional[bool] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        desc: Optional[bool] = None,
    ) -> List[dict]:
        r"""Get all named datasets from the Apify platform.

        Args:
            unnamed (bool, optional): Whether to include unnamed key-value
                stores in the list
            limit (int, optional): How many key-value stores to retrieve
            offset (int, optional): What key-value store to include as first
                when retrieving the list
            desc (bool, optional): Whether to sort the key-value stores in
                descending order based on their modification date

        Returns:
            List[dict]: The datasets.

        Raises:
            RuntimeError: If the datasets fail to be retrieved.
        """
        try:
            return (
                self.client.datasets()
                .list(unnamed=unnamed, limit=limit, offset=offset, desc=desc)
                .items
            )
        except Exception as e:
            raise RuntimeError(f"Failed to get datasets: {e}") from e
@@ -1,328 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from copy import deepcopy
|
|
||||||
from hashlib import md5
|
|
||||||
from io import BytesIO
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from camel.utils import dependencies_required
|
|
||||||
|
|
||||||
|
|
||||||
class File(ABC):
    r"""Represents an uploaded file comprised of Documents.

    Args:
        name (str): The name of the file.
        file_id (str): The unique identifier of the file.
        metadata (Dict[str, Any], optional): Additional metadata
            associated with the file. Defaults to None.
        docs (List[Dict[str, Any]], optional): A list of documents
            contained within the file. Defaults to None.
        raw_bytes (bytes, optional): The raw bytes content of the file.
            Defaults to b"".
    """

    def __init__(
        self,
        name: str,
        file_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        docs: Optional[List[Dict[str, Any]]] = None,
        raw_bytes: bytes = b"",
    ):
        self.name = name
        self.file_id = file_id
        self.metadata = metadata or {}
        self.docs = docs or []
        self.raw_bytes = raw_bytes

    @classmethod
    @abstractmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "File":
        r"""Creates a File object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        pass

    @classmethod
    def from_raw_bytes(cls, raw_bytes: bytes, filename: str) -> "File":
        r"""Creates a File object from raw bytes.

        Args:
            raw_bytes (bytes): The raw bytes content of the file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        return cls.from_bytes(BytesIO(raw_bytes), filename)

    @staticmethod
    def create_file(file: BytesIO, filename: str) -> "File":
        r"""Reads an uploaded file and returns a File object.

        The concrete subclass is selected from the filename extension.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        handlers: Dict[str, Any] = {
            "docx": DocxFile,
            "pdf": PdfFile,
            "txt": TxtFile,
            "json": JsonFile,
            "html": HtmlFile,
        }

        ext = filename.split(".")[-1].lower()
        file_cls = handlers.get(ext)
        if file_cls is None:
            raise NotImplementedError(f"File type {ext} not supported")
        return file_cls.from_bytes(file, filename)

    @staticmethod
    def create_file_from_raw_bytes(raw_bytes: bytes, filename: str) -> "File":
        r"""Reads raw bytes and returns a File object.

        Args:
            raw_bytes (bytes): The raw bytes content of the file.
            filename (str): The name of the file.

        Returns:
            File: A File object.
        """
        return File.create_file(BytesIO(raw_bytes), filename)

    def __repr__(self) -> str:
        return (
            f"File(name={self.name}, id={self.file_id}, "
            f"metadata={self.metadata}, docs={self.docs})"
        )

    def __str__(self) -> str:
        return (
            f"File(name={self.name}, id={self.file_id}, metadata="
            f"{self.metadata})"
        )

    def copy(self) -> "File":
        r"""Create a deep copy of this File"""
        # Metadata and docs are deep-copied; raw_bytes is immutable and
        # therefore shared.
        return self.__class__(
            name=self.name,
            file_id=self.file_id,
            metadata=deepcopy(self.metadata),
            docs=deepcopy(self.docs),
            raw_bytes=self.raw_bytes,
        )
|
|
||||||
|
|
||||||
def strip_consecutive_newlines(text: str) -> str:
    r"""Strips consecutive newlines from a string.

    Any run of whitespace that contains at least one newline is collapsed
    into a single ``\n``.

    Args:
        text (str): The string to strip.

    Returns:
        str: The string with consecutive newlines stripped.
    """
    pattern = re.compile(r"\s*\n\s*")
    return pattern.sub("\n", text)
|
|
||||||
|
|
||||||
class DocxFile(File):
    @classmethod
    @dependencies_required('docx2txt')
    def from_bytes(cls, file: BytesIO, filename: str) -> "DocxFile":
        r"""Creates a DocxFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                docx file.
            filename (str): The name of the file.

        Returns:
            DocxFile: A DocxFile object.
        """
        import docx2txt

        # Extract and normalize the document text.
        content = strip_consecutive_newlines(docx2txt.process(file))
        # The file's md5 digest serves as a stable unique identifier.
        checksum = md5(file.getvalue()).hexdigest()
        # Rewind so the stream remains usable for callers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=checksum,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
|
|
||||||
class PdfFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "PdfFile":
        r"""Creates a PdfFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                pdf file.
            filename (str): The name of the file.

        Returns:
            PdfFile: A PdfFile object.
        """
        # fitz (PyMuPDF) is an optional dependency, imported on demand.
        try:
            import fitz
        except ImportError:
            raise ImportError(
                "Please install `PyMuPDF` first. "
                "You can install it by running "
                "`pip install PyMuPDF`."
            )
        pdf = fitz.open(stream=file.read(), filetype="pdf")
        # One document dict per page, with 1-based page numbers.
        docs = [
            {
                "page_content": strip_consecutive_newlines(
                    page.get_text(sort=True)
                ).strip(),
                "page": page_number,
            }
            for page_number, page in enumerate(pdf, start=1)
        ]
        # The file's md5 digest serves as a stable unique identifier.
        checksum = md5(file.getvalue()).hexdigest()
        # Rewind so the stream remains usable for callers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=checksum,
            docs=docs,
            raw_bytes=file.getvalue(),
        )
|
|
||||||
class TxtFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "TxtFile":
        r"""Creates a TxtFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                txt file.
            filename (str): The name of the file.

        Returns:
            TxtFile: A TxtFile object.
        """
        # Decode as UTF-8 and normalize blank lines.
        content = strip_consecutive_newlines(file.read().decode("utf-8"))
        # The file's md5 digest serves as a stable unique identifier.
        checksum = md5(file.getvalue()).hexdigest()
        # Rewind so the stream remains usable for callers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=checksum,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
|
|
||||||
class JsonFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "JsonFile":
        r"""Creates a JsonFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                json file.
            filename (str): The name of the file.

        Returns:
            JsonFile: A JsonFile object.
        """
        # Round-trip through json to validate and canonicalize the content.
        content = json.dumps(json.load(file))
        # The file's md5 digest serves as a stable unique identifier.
        checksum = md5(file.getvalue()).hexdigest()
        # Rewind so the stream remains usable for callers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=checksum,
            docs=[{"page_content": content}],
            raw_bytes=file.getvalue(),
        )
|
|
||||||
class HtmlFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO, filename: str) -> "HtmlFile":
        r"""Creates a HtmlFile object from a BytesIO object.

        Args:
            file (BytesIO): A BytesIO object representing the contents of the
                html file.
            filename (str): The name of the file.

        Returns:
            HtmlFile: A HtmlFile object.
        """
        # BeautifulSoup is an optional dependency, imported on demand.
        try:
            from bs4 import BeautifulSoup
        except ImportError:
            raise ImportError(
                "Please install `beautifulsoup4` first. "
                "You can install it by running "
                "`pip install beautifulsoup4`."
            )
        # Extract the visible text and normalize blank lines.
        markup = BeautifulSoup(file, "html.parser")
        content = strip_consecutive_newlines(markup.get_text())
        # The file's md5 digest serves as a stable unique identifier.
        checksum = md5(file.getvalue()).hexdigest()
        # Rewind so the stream remains usable for callers.
        file.seek(0)
        return cls(
            name=filename,
            file_id=checksum,
            docs=[{"page_content": content.strip()}],
            raw_bytes=file.getvalue(),
        )
@@ -1,162 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import IO, Any, Optional, Union
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from camel.utils import api_keys_required
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkrReader:
    r"""Chunkr Reader for processing documents and returning content
    in various formats.

    Args:
        api_key (Optional[str], optional): The API key for Chunkr API. If not
            provided, it will be retrieved from the environment variable
            `CHUNKR_API_KEY`. (default: :obj:`None`)
        url (Optional[str], optional): The url to the Chunkr service.
            (default: :obj:`https://api.chunkr.ai/api/v1/task`)
        timeout (int, optional): The maximum time in seconds to wait for the
            API responses. (default: :obj:`30`)
        **kwargs (Any): Additional keyword arguments for request headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = "https://api.chunkr.ai/api/v1/task",
        timeout: int = 30,
        **kwargs: Any,
    ) -> None:
        self._api_key = api_key or os.getenv('CHUNKR_API_KEY')
        # An explicit CHUNKR_API_URL environment variable wins over the
        # `url` argument.
        self._url = os.getenv('CHUNKR_API_URL') or url
        self._headers = {
            "Authorization": f"{self._api_key}",
            **kwargs,
        }
        self.timeout = timeout

    def submit_task(
        self,
        file_path: str,
        model: str = "Fast",
        ocr_strategy: str = "Auto",
        target_chunk_length: str = "512",
    ) -> str:
        r"""Submits a file to the Chunkr API and returns the task ID.

        Args:
            file_path (str): The path to the file to be uploaded.
            model (str, optional): The model to be used for the task.
                (default: :obj:`Fast`)
            ocr_strategy (str, optional): The OCR strategy. Defaults to 'Auto'.
            target_chunk_length (str, optional): The target chunk length.
                (default: :obj:`512`)

        Returns:
            str: The task ID.
        """
        with open(file_path, 'rb') as file:
            multipart: dict[
                str, Union[tuple[None, IO[bytes]], tuple[None, str]]
            ] = {
                # Pass the open handle so requests streams the binary body.
                'file': (None, file),
                'model': (None, model),
                'ocr_strategy': (None, ocr_strategy),
                'target_chunk_length': (None, target_chunk_length),
            }
            try:
                response = requests.post(
                    self._url,  # type: ignore[arg-type]
                    headers=self._headers,
                    files=multipart,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                task_id = response.json().get('task_id')
                if not task_id:
                    raise ValueError("Task ID not returned in the response.")
                logger.info(f"Task submitted successfully. Task ID: {task_id}")
                return task_id
            except Exception as e:
                logger.error(f"Failed to submit task: {e}")
                raise ValueError(f"Failed to submit task: {e}") from e

    def get_task_output(self, task_id: str, max_retries: int = 5) -> str:
        r"""Polls the Chunkr API to check the task status and returns the task
        result.

        Args:
            task_id (str): The task ID to check the status for.
            max_retries (int, optional): Maximum number of retry attempts.
                (default: :obj:`5`)

        Returns:
            str: The formatted task result in JSON format.

        Raises:
            ValueError: If the task status cannot be retrieved.
            RuntimeError: If the maximum number of retries is reached without
                a successful task completion.
        """
        status_url = f"{self._url}/{task_id}"

        for _attempt in range(max_retries):
            try:
                response = requests.get(
                    status_url, headers=self._headers, timeout=self.timeout
                )
                response.raise_for_status()
                task_status = response.json().get('status')

                if task_status == "Succeeded":
                    logger.info(f"Task {task_id} completed successfully.")
                    return self._pretty_print_response(response.json())
                logger.info(
                    f"Task {task_id} is still {task_status}. Retrying "
                    "in 5 seconds..."
                )
            except Exception as e:
                logger.error(f"Failed to retrieve task status: {e}")
                raise ValueError(f"Failed to retrieve task status: {e}") from e

            time.sleep(5)

        logger.error(f"Max retries reached for task {task_id}.")
        raise RuntimeError(f"Max retries reached for task {task_id}.")

    def _pretty_print_response(self, response_json: dict) -> str:
        r"""Pretty prints the JSON response.

        Args:
            response_json (dict): The response JSON to pretty print.

        Returns:
            str: Formatted JSON as a string.
        """
        return json.dumps(response_json, indent=4)
||||||
@@ -1,202 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class Firecrawl:
    r"""Firecrawl allows you to turn entire websites into LLM-ready markdown.

    Args:
        api_key (Optional[str]): API key for authenticating with the Firecrawl
            API.
        api_url (Optional[str]): Base URL for the Firecrawl API.

    References:
        https://docs.firecrawl.dev/introduction
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
    ) -> None:
        # Imported lazily so the module loads even when the optional
        # `firecrawl` dependency is absent and this class is unused.
        from firecrawl import FirecrawlApp

        self._api_key = api_key or os.environ.get("FIRECRAWL_API_KEY")
        self._api_url = api_url or os.environ.get("FIRECRAWL_API_URL")

        self.app = FirecrawlApp(api_key=self._api_key, api_url=self._api_url)

    def crawl(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Crawl a URL and all accessible subpages. Customize the crawl by
        setting different parameters, and receive the full response or a job
        ID based on the specified options.

        Args:
            url (str): The URL to crawl.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                crawl request. Defaults to `None`.
            **kwargs (Any): Additional keyword arguments, such as
                `poll_interval`, `idempotency_key`.

        Returns:
            Any: The crawl job ID or the crawl results if waiting until
                completion.

        Raises:
            RuntimeError: If the crawling process fails.
        """

        try:
            crawl_response = self.app.crawl_url(
                url=url,
                params=params,
                **kwargs,
            )
            return crawl_response
        except Exception as e:
            # Chain the original exception so callers can see the root cause.
            raise RuntimeError(f"Failed to crawl the URL: {e}") from e

    def markdown_crawl(self, url: str) -> str:
        r"""Crawl a URL and all accessible subpages and return the content in
        Markdown format.

        Args:
            url (str): The URL to crawl.

        Returns:
            str: The content of the URL in Markdown format.

        Raises:
            RuntimeError: If the crawling process fails.
        """

        try:
            crawl_result = self.app.crawl_url(
                url,
                {'formats': ['markdown']},
            )
            # The SDK is expected to return one result dict per crawled page.
            if not isinstance(crawl_result, list):
                raise ValueError("Unexpected response format")
            markdown_contents = [
                result.get('markdown', '') for result in crawl_result
            ]
            return '\n'.join(markdown_contents)
        except Exception as e:
            raise RuntimeError(
                f"Failed to crawl the URL and retrieve markdown: {e}"
            ) from e

    def check_crawl_job(self, job_id: str) -> Dict:
        r"""Check the status of a crawl job.

        Args:
            job_id (str): The ID of the crawl job.

        Returns:
            Dict: The response including status of the crawl job.

        Raises:
            RuntimeError: If the check process fails.
        """

        try:
            return self.app.check_crawl_status(job_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to check the crawl job status: {e}"
            ) from e

    def scrape(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict:
        r"""To scrape a single URL. This function supports advanced scraping
        by setting different parameters and returns the full scraped data as a
        dictionary.

        Reference: https://docs.firecrawl.dev/advanced-scraping-guide

        Args:
            url (str): The URL to read.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                scrape request.

        Returns:
            Dict: The scraped data.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            return self.app.scrape_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to scrape the URL: {e}") from e

    def structured_scrape(self, url: str, response_format: BaseModel) -> Dict:
        r"""Use LLM to extract structured data from given URL.

        Args:
            url (str): The URL to read.
            response_format (BaseModel): A pydantic model
                that includes value types and field descriptions used to
                generate a structured response by LLM. This schema helps
                in defining the expected output format.

        Returns:
            Dict: The content of the URL.

        Raises:
            RuntimeError: If the scrape process fails.
        """
        try:
            data = self.app.scrape_url(
                url,
                {
                    'formats': ['extract'],
                    'extract': {'schema': response_format.model_json_schema()},
                },
            )
            return data.get("extract", {})
        except Exception as e:
            raise RuntimeError(
                f"Failed to perform structured scrape: {e}"
            ) from e

    def map_site(
        self, url: str, params: Optional[Dict[str, Any]] = None
    ) -> list:
        r"""Map a website to retrieve all accessible URLs.

        Args:
            url (str): The URL of the site to map.
            params (Optional[Dict[str, Any]]): Additional parameters for the
                map request. Defaults to `None`.

        Returns:
            list: A list containing the URLs found on the site.

        Raises:
            RuntimeError: If the mapping process fails.
        """
        try:
            return self.app.map_url(url=url, params=params)
        except Exception as e:
            raise RuntimeError(f"Failed to map the site: {e}") from e
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Optional
|
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
from camel.types.enums import JinaReturnFormat
|
|
||||||
|
|
||||||
JINA_ENDPOINT = "https://r.jina.ai/"
|
|
||||||
|
|
||||||
|
|
||||||
class JinaURLReader:
    r"""URL Reader provided by Jina AI. The output is cleaner and more
    LLM-friendly than the URL Reader of UnstructuredIO. Can be configured to
    replace the UnstructuredIO URL Reader in the pipeline.

    Args:
        api_key (Optional[str], optional): The API key for Jina AI. If not
            provided, the reader will have a lower rate limit. Defaults to
            None.
        return_format (ReturnFormat, optional): The level of detail
            of the returned content, which is optimized for LLMs. For
            now screenshots are not supported. Defaults to
            ReturnFormat.DEFAULT.
        json_response (bool, optional): Whether to return the response
            in JSON format. Defaults to False.
        timeout (int, optional): The maximum time in seconds to wait for
            the page to be rendered. Defaults to 30.
        **kwargs (Any): Additional keyword arguments, including proxies,
            cookies, etc. It should align with the HTTP Header field and
            value pairs listed in the reference.

    References:
        https://jina.ai/reader
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        return_format: JinaReturnFormat = JinaReturnFormat.DEFAULT,
        json_response: bool = False,
        timeout: int = 30,
        **kwargs: Any,
    ) -> None:
        api_key = api_key or os.getenv('JINA_API_KEY')
        if not api_key:
            warn(
                "JINA_API_KEY not set. This will result in a low rate limit "
                "of Jina URL Reader. Get API key here: https://jina.ai/reader."
            )

        # if the following field not provided, it will be None
        api_field = f"Bearer {api_key}" if api_key else None
        json_field = "application/json" if json_response else None

        raw_headers = {
            "Authorization": api_field,
            "X-Return-Format": return_format.value,
            "Accept": json_field,
            "X-Timeout": str(timeout),
            **kwargs,
        }

        # eliminate None values
        self._headers = {k: v for k, v in raw_headers.items() if v}
        # Keep the timeout so the HTTP request itself can be bounded too
        # (X-Timeout only limits server-side page rendering).
        self._timeout = timeout

    def read_content(self, url: str) -> str:
        r"""Reads the content of a URL and returns it as a string with
        given form.

        Args:
            url (str): The URL to read.

        Returns:
            str: The content of the URL.

        Raises:
            ValueError: If the content cannot be retrieved from the URL.
        """

        import requests

        full_url = f"{JINA_ENDPOINT}{url}"
        try:
            # Without a client-side timeout, requests.get can block forever
            # if the reader service stalls.
            resp = requests.get(
                full_url, headers=self._headers, timeout=self._timeout
            )
            resp.raise_for_status()
        except Exception as e:
            raise ValueError(f"Failed to read content from {url}: {e}") from e

        return resp.text
|
|
||||||
@@ -1,471 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import uuid
|
|
||||||
import warnings
|
|
||||||
from typing import (
|
|
||||||
IO,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from unstructured.documents.elements import Element
|
|
||||||
import pdb
|
|
||||||
|
|
||||||
class UnstructuredIO:
    r"""A class to handle various functionalities provided by the
    Unstructured library, including version checking, parsing, cleaning,
    extracting, staging, chunking data, and integrating with cloud
    services like S3 and Azure for data connection.

    References:
        https://docs.unstructured.io/
    """

    @staticmethod
    def create_element_from_text(
        text: str,
        element_id: Optional[str] = None,
        embeddings: Optional[List[float]] = None,
        filename: Optional[str] = None,
        file_directory: Optional[str] = None,
        last_modified: Optional[str] = None,
        filetype: Optional[str] = None,
        parent_id: Optional[str] = None,
    ) -> "Element":
        r"""Creates a Text element from a given text input, with optional
        metadata and embeddings.

        Args:
            text (str): The text content for the element.
            element_id (Optional[str], optional): Unique identifier for the
                element. (default: :obj:`None`)
            embeddings (List[float], optional): A list of float
                numbers representing the text embeddings.
                (default: :obj:`None`)
            filename (Optional[str], optional): The name of the file the
                element is associated with. (default: :obj:`None`)
            file_directory (Optional[str], optional): The directory path where
                the file is located. (default: :obj:`None`)
            last_modified (Optional[str], optional): The last modified date of
                the file. (default: :obj:`None`)
            filetype (Optional[str], optional): The type of the file.
                (default: :obj:`None`)
            parent_id (Optional[str], optional): The identifier of the parent
                element. (default: :obj:`None`)

        Returns:
            Element: An instance of Text with the provided content and
                metadata.
        """
        from unstructured.documents.elements import ElementMetadata, Text

        metadata = ElementMetadata(
            filename=filename,
            file_directory=file_directory,
            last_modified=last_modified,
            filetype=filetype,
            parent_id=parent_id,
        )

        # A random UUID is generated when the caller supplies no element_id.
        return Text(
            text=text,
            element_id=element_id or str(uuid.uuid4()),
            metadata=metadata,
            embeddings=embeddings,
        )

    @staticmethod
    def parse_file_or_url(
        input_path: str,
        **kwargs: Any,
    ) -> Union[List["Element"], None]:
        r"""Loads a file or a URL and parses its contents into elements.

        Args:
            input_path (str): Path to the file or URL to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element],None]: List of elements after parsing the file
                or URL if success.

        Raises:
            FileNotFoundError: If the file does not exist at the path
                specified.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://unstructured-io.github.io/unstructured/
        """
        import os
        from urllib.parse import urlparse

        from unstructured.partition.auto import partition

        # Check if the input is a URL: a URL must carry both a scheme
        # (e.g. "https") and a network location.
        parsed_url = urlparse(input_path)
        is_url = all([parsed_url.scheme, parsed_url.netloc])

        # Handling URL
        if is_url:
            try:
                elements = partition(url=input_path, **kwargs)
                return elements
            except Exception:
                # Best-effort: parse failures surface as a warning, not an
                # exception, so callers get None instead of a crash.
                warnings.warn(f"Failed to parse the URL: {input_path}")
                return None

        # Handling file
        else:
            # Check if the file exists
            if not os.path.exists(input_path):
                raise FileNotFoundError(
                    f"The file {input_path} was not found."
                )

            # Read the file
            try:
                with open(input_path, "rb") as f:
                    elements = partition(file=f, **kwargs)
                    return elements
            except Exception:
                warnings.warn(f"Failed to partition the file: {input_path}")
                return None

    @staticmethod
    def parse_bytes(
        file: IO[bytes], **kwargs: Any
    ) -> Union[List["Element"], None]:
        r"""Parses a bytes stream and converts its contents into elements.

        Args:
            file (IO[bytes]): The file in bytes format to be parsed.
            **kwargs: Extra kwargs passed to the partition function.

        Returns:
            Union[List[Element], None]: List of elements after parsing the file
                if successful, otherwise `None`.

        Notes:
            Supported file types:
            "csv", "doc", "docx", "epub", "image", "md", "msg", "odt",
            "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx".

        References:
            https://docs.unstructured.io/open-source/core-functionality/partitioning
        """

        from unstructured.partition.auto import partition

        try:
            # Use partition to process the bytes stream
            elements = partition(file=file, **kwargs)
            return elements
        except Exception as e:
            warnings.warn(f"Failed to partition the file stream: {e}")
            return None

    @staticmethod
    def clean_text_data(
        text: str,
        clean_options: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
    ) -> str:
        r"""Cleans text data using a variety of cleaning functions provided by
        the `unstructured` library.

        This function applies multiple text cleaning utilities by calling the
        `unstructured` library's cleaning bricks for operations like
        replacing Unicode quotes, removing extra whitespace, dashes, non-ascii
        characters, and more.

        If no cleaning options are provided, a default set of cleaning
        operations is applied. These defaults including operations
        "replace_unicode_quotes", "clean_non_ascii_chars",
        "group_broken_paragraphs", and "clean_extra_whitespace".

        Args:
            text (str): The text to be cleaned.
            clean_options (dict): A dictionary specifying which cleaning
                options to apply. The keys should match the names of the
                cleaning functions, and the values should be dictionaries
                containing the parameters for each function. Supported types:
                'clean_extra_whitespace', 'clean_bullets',
                'clean_ordered_bullets', 'clean_postfix', 'clean_prefix',
                'clean_dashes', 'clean_trailing_punctuation',
                'clean_non_ascii_chars', 'group_broken_paragraphs',
                'remove_punctuation', 'replace_unicode_quotes',
                'bytes_string_to_string', 'translate_text'.

        Returns:
            str: The cleaned text.

        Raises:
            AttributeError: If a cleaning option does not correspond to a
                valid cleaning function in `unstructured`.

        Notes:
            The 'options' dictionary keys must correspond to valid cleaning
            brick names from the `unstructured` library.
            Each brick's parameters must be provided in a nested dictionary
            as the value for the key.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.cleaners.core import (
            bytes_string_to_string,
            clean_bullets,
            clean_dashes,
            clean_extra_whitespace,
            clean_non_ascii_chars,
            clean_ordered_bullets,
            clean_postfix,
            clean_prefix,
            clean_trailing_punctuation,
            group_broken_paragraphs,
            remove_punctuation,
            replace_unicode_quotes,
        )
        from unstructured.cleaners.translate import translate_text

        # Dispatch table mapping option names to the library's cleaning
        # bricks; lookups below fail fast on unknown names.
        cleaning_functions: Any = {
            "clean_extra_whitespace": clean_extra_whitespace,
            "clean_bullets": clean_bullets,
            "clean_ordered_bullets": clean_ordered_bullets,
            "clean_postfix": clean_postfix,
            "clean_prefix": clean_prefix,
            "clean_dashes": clean_dashes,
            "clean_trailing_punctuation": clean_trailing_punctuation,
            "clean_non_ascii_chars": clean_non_ascii_chars,
            "group_broken_paragraphs": group_broken_paragraphs,
            "remove_punctuation": remove_punctuation,
            "replace_unicode_quotes": replace_unicode_quotes,
            "bytes_string_to_string": bytes_string_to_string,
            "translate_text": translate_text,
        }

        # Define default clean options if none are provided
        if clean_options is None:
            clean_options = [
                ("replace_unicode_quotes", {}),
                ("clean_non_ascii_chars", {}),
                ("group_broken_paragraphs", {}),
                ("clean_extra_whitespace", {}),
            ]

        # Apply the selected bricks in order; each feeds the next.
        cleaned_text = text
        for func_name, params in clean_options:
            if func_name in cleaning_functions:
                cleaned_text = cleaning_functions[func_name](
                    cleaned_text, **params
                )
            else:
                raise ValueError(
                    f"'{func_name}' is not a valid function in "
                    "`Unstructured IO`."
                )

        return cleaned_text

    @staticmethod
    def extract_data_from_text(
        text: str,
        extract_type: Literal[
            'extract_datetimetz',
            'extract_email_address',
            'extract_ip_address',
            'extract_ip_address_name',
            'extract_mapi_id',
            'extract_ordered_bullets',
            'extract_text_after',
            'extract_text_before',
            'extract_us_phone_number',
        ],
        **kwargs,
    ) -> Any:
        r"""Extracts various types of data from text using functions from
        unstructured.cleaners.extract.

        Args:
            text (str): Text to extract data from.
            extract_type (Literal['extract_datetimetz',
                'extract_email_address', 'extract_ip_address',
                'extract_ip_address_name', 'extract_mapi_id',
                'extract_ordered_bullets', 'extract_text_after',
                'extract_text_before', 'extract_us_phone_number']): Type of
                data to extract.
            **kwargs: Additional keyword arguments for specific
                extraction functions.

        Returns:
            Any: The extracted data, type depends on extract_type.

        Raises:
            ValueError: If `extract_type` is not a supported extraction.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.cleaners.extract import (
            extract_datetimetz,
            extract_email_address,
            extract_ip_address,
            extract_ip_address_name,
            extract_mapi_id,
            extract_ordered_bullets,
            extract_text_after,
            extract_text_before,
            extract_us_phone_number,
        )

        # Dispatch table from extract_type to the corresponding brick.
        extraction_functions: Any = {
            "extract_datetimetz": extract_datetimetz,
            "extract_email_address": extract_email_address,
            "extract_ip_address": extract_ip_address,
            "extract_ip_address_name": extract_ip_address_name,
            "extract_mapi_id": extract_mapi_id,
            "extract_ordered_bullets": extract_ordered_bullets,
            "extract_text_after": extract_text_after,
            "extract_text_before": extract_text_before,
            "extract_us_phone_number": extract_us_phone_number,
        }

        if extract_type not in extraction_functions:
            raise ValueError(f"Unsupported extract_type: {extract_type}")

        return extraction_functions[extract_type](text, **kwargs)

    @staticmethod
    def stage_elements(
        elements: List[Any],
        stage_type: Literal[
            'convert_to_csv',
            'convert_to_dataframe',
            'convert_to_dict',
            'dict_to_elements',
            'stage_csv_for_prodigy',
            'stage_for_prodigy',
            'stage_for_baseplate',
            'stage_for_datasaur',
            'stage_for_label_box',
            'stage_for_label_studio',
            'stage_for_weaviate',
        ],
        **kwargs,
    ) -> Union[str, List[Dict], Any]:
        r"""Stages elements for various platforms based on the
        specified staging type.

        This function applies multiple staging utilities to format data
        for different NLP annotation and machine learning tools. It uses
        the 'unstructured.staging' module's functions for operations like
        converting to CSV, DataFrame, dictionary, or formatting for
        specific platforms like Prodigy, etc.

        Args:
            elements (List[Any]): List of Element objects to be staged.
            stage_type (Literal['convert_to_csv', 'convert_to_dataframe',
                'convert_to_dict', 'dict_to_elements',
                'stage_csv_for_prodigy', 'stage_for_prodigy',
                'stage_for_baseplate', 'stage_for_datasaur',
                'stage_for_label_box', 'stage_for_label_studio',
                'stage_for_weaviate']): Type of staging to perform.
            **kwargs: Additional keyword arguments specific to
                the staging type.

        Returns:
            Union[str, List[Dict], Any]: Staged data in the
                format appropriate for the specified staging type.

        Raises:
            ValueError: If the staging type is not supported or a required
                argument is missing.
        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.staging import (
            base,
            baseplate,
            datasaur,
            label_box,
            label_studio,
            prodigy,
            weaviate,
        )

        # Dispatch table; lambdas adapt stagers whose extra arguments
        # (metadata/entities) arrive through **kwargs.
        staging_functions: Any = {
            "convert_to_csv": base.convert_to_csv,
            "convert_to_dataframe": base.convert_to_dataframe,
            "convert_to_dict": base.convert_to_dict,
            "dict_to_elements": base.dict_to_elements,
            "stage_csv_for_prodigy": lambda els,
            **kw: prodigy.stage_csv_for_prodigy(els, kw.get('metadata', [])),
            "stage_for_prodigy": lambda els, **kw: prodigy.stage_for_prodigy(
                els, kw.get('metadata', [])
            ),
            "stage_for_baseplate": baseplate.stage_for_baseplate,
            "stage_for_datasaur": lambda els,
            **kw: datasaur.stage_for_datasaur(els, kw.get('entities', [])),
            "stage_for_label_box": lambda els,
            **kw: label_box.stage_for_label_box(els, **kw),
            "stage_for_label_studio": lambda els,
            **kw: label_studio.stage_for_label_studio(els, **kw),
            "stage_for_weaviate": weaviate.stage_for_weaviate,
        }

        if stage_type not in staging_functions:
            raise ValueError(f"Unsupported stage type: {stage_type}")

        return staging_functions[stage_type](elements, **kwargs)

    @staticmethod
    def chunk_elements(
        elements: List["Element"], chunk_type: str, **kwargs
    ) -> List["Element"]:
        r"""Chunks elements by titles.

        Args:
            elements (List[Element]): List of Element objects to be chunked.
            chunk_type (str): Type chunk going to apply. Supported types:
                'chunk_by_title'.
            **kwargs: Additional keyword arguments for chunking.

        Returns:
            List[Dict]: List of chunked sections.

        Raises:
            ValueError: If `chunk_type` is not a supported chunking strategy.

        References:
            https://unstructured-io.github.io/unstructured/
        """

        from unstructured.chunking.title import chunk_by_title

        chunking_functions = {
            "chunk_by_title": chunk_by_title,
        }

        if chunk_type not in chunking_functions:
            raise ValueError(f"Unsupported chunk type: {chunk_type}")

        # Format chunks into a list of dictionaries (or your preferred format)
        return chunking_functions[chunk_type](elements, **kwargs)
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Create a private logger
|
|
||||||
_logger = logging.getLogger('camel')
|
|
||||||
|
|
||||||
|
|
||||||
def _configure_library_logging():
    r"""Install the library's default logging setup unless it is disabled
    or a configuration already exists.
    """
    # Respect the opt-out environment flag.
    if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').lower() == 'true':
        return

    already_configured = bool(logging.root.handlers) or bool(_logger.handlers)
    if already_configured:
        _logger.debug("Existing logger configuration found, using that.")
        return

    logging.basicConfig(
        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        stream=sys.stdout,
    )
    logging.setLoggerClass(logging.Logger)
    _logger.info("Camel library logging has been configured.")
|
|
||||||
|
|
||||||
|
|
||||||
def disable_logging():
|
|
||||||
r"""Disable all logging for the Camel library.
|
|
||||||
|
|
||||||
This function sets the log level to a value higher than CRITICAL,
|
|
||||||
effectively disabling all log messages, and adds a NullHandler to
|
|
||||||
suppress any potential warnings about no handlers being found.
|
|
||||||
"""
|
|
||||||
os.environ['CAMEL_LOGGING_DISABLED'] = 'true'
|
|
||||||
_logger.setLevel(logging.CRITICAL + 1)
|
|
||||||
# Avoid adding multiple NullHandlers
|
|
||||||
if not any(
|
|
||||||
isinstance(handler, logging.NullHandler)
|
|
||||||
for handler in _logger.handlers
|
|
||||||
):
|
|
||||||
_logger.addHandler(logging.NullHandler())
|
|
||||||
_logger.debug("Logging has been disabled.")
|
|
||||||
|
|
||||||
|
|
||||||
def enable_logging():
|
|
||||||
r"""Enable logging for the Camel library.
|
|
||||||
|
|
||||||
This function re-enables logging if it was previously disabled,
|
|
||||||
and configures the library logging using the default settings.
|
|
||||||
If the logging is already configured,
|
|
||||||
this function does not change its configuration.
|
|
||||||
"""
|
|
||||||
os.environ['CAMEL_LOGGING_DISABLED'] = 'false'
|
|
||||||
_configure_library_logging()
|
|
||||||
|
|
||||||
|
|
||||||
def set_log_level(level):
|
|
||||||
r"""Set the logging level for the Camel library.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
level (Union[str, int]): The logging level to set. This can be a string
|
|
||||||
(e.g., 'INFO') or a logging level constant (e.g., logging.INFO,
|
|
||||||
logging.DEBUG).
|
|
||||||
See https://docs.python.org/3/library/logging.html#levels
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the provided level is not a valid logging level.
|
|
||||||
"""
|
|
||||||
valid_levels = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
|
|
||||||
if isinstance(level, str):
|
|
||||||
if level.upper() not in valid_levels:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid logging level."
|
|
||||||
f" Choose from: {', '.join(valid_levels)}"
|
|
||||||
)
|
|
||||||
level = level.upper()
|
|
||||||
elif not isinstance(level, int):
|
|
||||||
raise ValueError(
|
|
||||||
"Logging level must be an option from the logging module."
|
|
||||||
)
|
|
||||||
|
|
||||||
_logger.setLevel(level)
|
|
||||||
_logger.debug(f"Logging level set to: {logging.getLevelName(level)}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name):
|
|
||||||
r"""Get a logger with the specified name, prefixed with 'camel.'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name to be appended to 'camel.' to create the logger.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logging.Logger: A logger instance with the name 'camel.{name}'.
|
|
||||||
"""
|
|
||||||
return logging.getLogger(f'camel.{name}')
|
|
||||||
|
|
||||||
|
|
||||||
# Lazy configuration: Only configure logging if explicitly enabled.
|
|
||||||
if os.environ.get('CAMEL_LOGGING_DISABLED', 'False').strip().lower() != 'true':
|
|
||||||
_configure_library_logging()
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .agent_memories import (
|
|
||||||
ChatHistoryMemory,
|
|
||||||
LongtermAgentMemory,
|
|
||||||
VectorDBMemory,
|
|
||||||
)
|
|
||||||
from .base import AgentMemory, BaseContextCreator, MemoryBlock
|
|
||||||
from .blocks.chat_history_block import ChatHistoryBlock
|
|
||||||
from .blocks.vectordb_block import VectorDBBlock
|
|
||||||
from .context_creators.score_based import ScoreBasedContextCreator
|
|
||||||
from .records import ContextRecord, MemoryRecord
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'MemoryRecord',
|
|
||||||
'ContextRecord',
|
|
||||||
'MemoryBlock',
|
|
||||||
"AgentMemory",
|
|
||||||
'BaseContextCreator',
|
|
||||||
'ScoreBasedContextCreator',
|
|
||||||
'ChatHistoryMemory',
|
|
||||||
'VectorDBMemory',
|
|
||||||
'ChatHistoryBlock',
|
|
||||||
'VectorDBBlock',
|
|
||||||
'LongtermAgentMemory',
|
|
||||||
]
|
|
||||||
@@ -1,176 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from camel.memories.base import AgentMemory, BaseContextCreator
|
|
||||||
from camel.memories.blocks import ChatHistoryBlock, VectorDBBlock
|
|
||||||
from camel.memories.records import ContextRecord, MemoryRecord
|
|
||||||
from camel.storages import BaseKeyValueStorage, BaseVectorStorage
|
|
||||||
from camel.types import OpenAIBackendRole
|
|
||||||
|
|
||||||
|
|
||||||
class ChatHistoryMemory(AgentMemory):
|
|
||||||
r"""An agent memory wrapper of :obj:`ChatHistoryBlock`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_creator (BaseContextCreator): A model context creator.
|
|
||||||
storage (BaseKeyValueStorage, optional): A storage backend for storing
|
|
||||||
chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
|
|
||||||
will be used. (default: :obj:`None`)
|
|
||||||
window_size (int, optional): The number of recent chat messages to
|
|
||||||
retrieve. If not provided, the entire chat history will be
|
|
||||||
retrieved. (default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
context_creator: BaseContextCreator,
|
|
||||||
storage: Optional[BaseKeyValueStorage] = None,
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
if window_size is not None and not isinstance(window_size, int):
|
|
||||||
raise TypeError("`window_size` must be an integer or None.")
|
|
||||||
if window_size is not None and window_size < 0:
|
|
||||||
raise ValueError("`window_size` must be non-negative.")
|
|
||||||
self._context_creator = context_creator
|
|
||||||
self._window_size = window_size
|
|
||||||
self._chat_history_block = ChatHistoryBlock(storage=storage)
|
|
||||||
|
|
||||||
def retrieve(self) -> List[ContextRecord]:
|
|
||||||
return self._chat_history_block.retrieve(self._window_size)
|
|
||||||
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
self._chat_history_block.write_records(records)
|
|
||||||
|
|
||||||
def get_context_creator(self) -> BaseContextCreator:
|
|
||||||
return self._context_creator
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self._chat_history_block.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class VectorDBMemory(AgentMemory):
|
|
||||||
r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries
|
|
||||||
messages stored in the vector database. Notice that the most recent
|
|
||||||
messages will not be added to the context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_creator (BaseContextCreator): A model context creator.
|
|
||||||
storage (BaseVectorStorage, optional): A vector storage storage. If
|
|
||||||
`None`, an :obj:`QdrantStorage` will be used.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
retrieve_limit (int, optional): The maximum number of messages
|
|
||||||
to be added into the context. (default: :obj:`3`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
context_creator: BaseContextCreator,
|
|
||||||
storage: Optional[BaseVectorStorage] = None,
|
|
||||||
retrieve_limit: int = 3,
|
|
||||||
) -> None:
|
|
||||||
self._context_creator = context_creator
|
|
||||||
self._retrieve_limit = retrieve_limit
|
|
||||||
self._vectordb_block = VectorDBBlock(storage=storage)
|
|
||||||
|
|
||||||
self._current_topic: str = ""
|
|
||||||
|
|
||||||
def retrieve(self) -> List[ContextRecord]:
|
|
||||||
return self._vectordb_block.retrieve(
|
|
||||||
self._current_topic,
|
|
||||||
limit=self._retrieve_limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
# Assume the last user input is the current topic.
|
|
||||||
for record in records:
|
|
||||||
if record.role_at_backend == OpenAIBackendRole.USER:
|
|
||||||
self._current_topic = record.message.content
|
|
||||||
self._vectordb_block.write_records(records)
|
|
||||||
|
|
||||||
def get_context_creator(self) -> BaseContextCreator:
|
|
||||||
return self._context_creator
|
|
||||||
|
|
||||||
|
|
||||||
class LongtermAgentMemory(AgentMemory):
|
|
||||||
r"""An implementation of the :obj:`AgentMemory` abstract base class for
|
|
||||||
augmenting ChatHistoryMemory with VectorDBMemory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_creator (BaseContextCreator): A model context creator.
|
|
||||||
chat_history_block (Optional[ChatHistoryBlock], optional): A chat
|
|
||||||
history block. If `None`, a :obj:`ChatHistoryBlock` will be used.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
vector_db_block (Optional[VectorDBBlock], optional): A vector database
|
|
||||||
block. If `None`, a :obj:`VectorDBBlock` will be used.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
retrieve_limit (int, optional): The maximum number of messages
|
|
||||||
to be added into the context. (default: :obj:`3`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
context_creator: BaseContextCreator,
|
|
||||||
chat_history_block: Optional[ChatHistoryBlock] = None,
|
|
||||||
vector_db_block: Optional[VectorDBBlock] = None,
|
|
||||||
retrieve_limit: int = 3,
|
|
||||||
) -> None:
|
|
||||||
self.chat_history_block = chat_history_block or ChatHistoryBlock()
|
|
||||||
self.vector_db_block = vector_db_block or VectorDBBlock()
|
|
||||||
self.retrieve_limit = retrieve_limit
|
|
||||||
self._context_creator = context_creator
|
|
||||||
self._current_topic: str = ""
|
|
||||||
|
|
||||||
def get_context_creator(self) -> BaseContextCreator:
|
|
||||||
r"""Returns the context creator used by the memory.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseContextCreator: The context creator used by the memory.
|
|
||||||
"""
|
|
||||||
return self._context_creator
|
|
||||||
|
|
||||||
def retrieve(self) -> List[ContextRecord]:
|
|
||||||
r"""Retrieves context records from both the chat history and the vector
|
|
||||||
database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ContextRecord]: A list of context records retrieved from both
|
|
||||||
the chat history and the vector database.
|
|
||||||
"""
|
|
||||||
chat_history = self.chat_history_block.retrieve()
|
|
||||||
vector_db_retrieve = self.vector_db_block.retrieve(
|
|
||||||
self._current_topic, self.retrieve_limit
|
|
||||||
)
|
|
||||||
return chat_history[:1] + vector_db_retrieve + chat_history[1:]
|
|
||||||
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
r"""Converts the provided chat messages into vector representations and
|
|
||||||
writes them to the vector database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[MemoryRecord]): Messages to be added to the vector
|
|
||||||
database.
|
|
||||||
"""
|
|
||||||
self.vector_db_block.write_records(records)
|
|
||||||
self.chat_history_block.write_records(records)
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
if record.role_at_backend == OpenAIBackendRole.USER:
|
|
||||||
self._current_topic = record.message.content
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
r"""Removes all records from the memory."""
|
|
||||||
self.chat_history_block.clear()
|
|
||||||
self.vector_db_block.clear()
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from camel.memories.records import ContextRecord, MemoryRecord
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.utils import BaseTokenCounter
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBlock(ABC):
|
|
||||||
r"""An abstract class serves as the fundamental component within the agent
|
|
||||||
memory system. This class is equipped with "write" and "clear" functions.
|
|
||||||
However, it intentionally does not define a retrieval interface, as the
|
|
||||||
structure of the data to be retrieved may vary in different types of
|
|
||||||
memory blocks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
r"""Writes records to the memory, appending them to existing ones.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[MemoryRecord]): Records to be added to the memory.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def write_record(self, record: MemoryRecord) -> None:
|
|
||||||
r"""Writes a record to the memory, appending it to existing ones.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record (MemoryRecord): Record to be added to the memory.
|
|
||||||
"""
|
|
||||||
self.write_records([record])
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear(self) -> None:
|
|
||||||
r"""Clears all messages from the memory."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseContextCreator(ABC):
|
|
||||||
r"""An abstract base class defining the interface for context creation
|
|
||||||
strategies.
|
|
||||||
|
|
||||||
This class provides a foundational structure for different strategies to
|
|
||||||
generate conversational context from a list of context records. The
|
|
||||||
primary goal is to create a context that is aligned with a specified token
|
|
||||||
count limit, allowing subclasses to define their specific approach.
|
|
||||||
|
|
||||||
Subclasses should implement the :obj:`token_counter`,:obj: `token_limit`,
|
|
||||||
and :obj:`create_context` methods to provide specific context creation
|
|
||||||
logic.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
token_counter (BaseTokenCounter): A token counter instance responsible
|
|
||||||
for counting tokens in a message.
|
|
||||||
token_limit (int): The maximum number of tokens allowed in the
|
|
||||||
generated context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def token_counter(self) -> BaseTokenCounter:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def token_limit(self) -> int:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_context(
|
|
||||||
self,
|
|
||||||
records: List[ContextRecord],
|
|
||||||
) -> Tuple[List[OpenAIMessage], int]:
|
|
||||||
r"""An abstract method to create conversational context from the chat
|
|
||||||
history.
|
|
||||||
|
|
||||||
Constructs the context from provided records. The specifics of how this
|
|
||||||
is done and how the token count is managed should be provided by
|
|
||||||
subclasses implementing this method. The output messages order
|
|
||||||
should keep same as the input order.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[ContextRecord]): A list of context records from
|
|
||||||
which to generate the context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[OpenAIMessage], int]: A tuple containing the constructed
|
|
||||||
context in OpenAIMessage format and the total token count.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AgentMemory(MemoryBlock, ABC):
|
|
||||||
r"""Represents a specialized form of `MemoryBlock`, uniquely designed for
|
|
||||||
direct integration with an agent. Two key abstract functions, "retrieve"
|
|
||||||
and "get_context_creator", are used for generating model context based on
|
|
||||||
the memory records stored within the AgentMemory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def retrieve(self) -> List[ContextRecord]:
|
|
||||||
r"""Get a record list from the memory for creating model context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ContextRecord]: A record list for creating model context.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_context_creator(self) -> BaseContextCreator:
|
|
||||||
r"""Gets context creator.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseContextCreator: A model context creator.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_context(self) -> Tuple[List[OpenAIMessage], int]:
|
|
||||||
r"""Gets chat context with a proper size for the agent from the memory.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(List[OpenAIMessage], int): A tuple containing the constructed
|
|
||||||
context in OpenAIMessage format and the total token count.
|
|
||||||
"""
|
|
||||||
return self.get_context_creator().create_context(self.retrieve())
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .chat_history_block import ChatHistoryBlock
|
|
||||||
from .vectordb_block import VectorDBBlock
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'ChatHistoryBlock',
|
|
||||||
'VectorDBBlock',
|
|
||||||
]
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import warnings
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from camel.memories.base import MemoryBlock
|
|
||||||
from camel.memories.records import ContextRecord, MemoryRecord
|
|
||||||
from camel.storages import BaseKeyValueStorage, InMemoryKeyValueStorage
|
|
||||||
from camel.types import OpenAIBackendRole
|
|
||||||
|
|
||||||
|
|
||||||
class ChatHistoryBlock(MemoryBlock):
|
|
||||||
r"""An implementation of the :obj:`MemoryBlock` abstract base class for
|
|
||||||
maintaining a record of chat histories.
|
|
||||||
|
|
||||||
This memory block helps manage conversation histories with a key-value
|
|
||||||
storage backend, either provided by the user or using a default
|
|
||||||
in-memory storage. It offers a windowed approach to retrieving chat
|
|
||||||
histories, allowing users to specify how many recent messages they'd
|
|
||||||
like to fetch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage (BaseKeyValueStorage, optional): A storage mechanism for
|
|
||||||
storing chat history. If `None`, an :obj:`InMemoryKeyValueStorage`
|
|
||||||
will be used. (default: :obj:`None`)
|
|
||||||
keep_rate (float, optional): In historical messages, the score of the
|
|
||||||
last message is 1.0, and with each step taken backward, the score
|
|
||||||
of the message is multiplied by the `keep_rate`. Higher `keep_rate`
|
|
||||||
leads to high possiblity to keep history messages during context
|
|
||||||
creation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
storage: Optional[BaseKeyValueStorage] = None,
|
|
||||||
keep_rate: float = 0.9,
|
|
||||||
) -> None:
|
|
||||||
if keep_rate > 1 or keep_rate < 0:
|
|
||||||
raise ValueError("`keep_rate` should be in [0,1]")
|
|
||||||
self.storage = storage or InMemoryKeyValueStorage()
|
|
||||||
self.keep_rate = keep_rate
|
|
||||||
|
|
||||||
def retrieve(
|
|
||||||
self,
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
) -> List[ContextRecord]:
|
|
||||||
r"""Retrieves records with a proper size for the agent from the memory
|
|
||||||
based on the window size or fetches the entire chat history if no
|
|
||||||
window size is specified.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
window_size (int, optional): Specifies the number of recent chat
|
|
||||||
messages to retrieve. If not provided, the entire chat history
|
|
||||||
will be retrieved. (default: :obj:`None`)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ContextRecord]: A list of retrieved records.
|
|
||||||
"""
|
|
||||||
record_dicts = self.storage.load()
|
|
||||||
if len(record_dicts) == 0:
|
|
||||||
warnings.warn("The `ChatHistoryMemory` is empty.")
|
|
||||||
return list()
|
|
||||||
|
|
||||||
chat_records: List[MemoryRecord] = []
|
|
||||||
truncate_idx = -window_size if window_size is not None else 0
|
|
||||||
for record_dict in record_dicts[truncate_idx:]:
|
|
||||||
chat_records.append(MemoryRecord.from_dict(record_dict))
|
|
||||||
|
|
||||||
# We assume that, in the chat history memory, the closer the record is
|
|
||||||
# to the current message, the more score it will be.
|
|
||||||
output_records = []
|
|
||||||
score = 1.0
|
|
||||||
for record in reversed(chat_records):
|
|
||||||
if record.role_at_backend == OpenAIBackendRole.SYSTEM:
|
|
||||||
# System messages are always kept.
|
|
||||||
output_records.append(
|
|
||||||
ContextRecord(memory_record=record, score=1.0)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Other messages' score drops down gradually
|
|
||||||
score *= self.keep_rate
|
|
||||||
output_records.append(
|
|
||||||
ContextRecord(memory_record=record, score=score)
|
|
||||||
)
|
|
||||||
|
|
||||||
output_records.reverse()
|
|
||||||
return output_records
|
|
||||||
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
r"""Writes memory records to the memory. Additionally, performs
|
|
||||||
validation checks on the messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[MemoryRecord]): Memory records to be added to the
|
|
||||||
memory.
|
|
||||||
"""
|
|
||||||
stored_records = []
|
|
||||||
for record in records:
|
|
||||||
stored_records.append(record.to_dict())
|
|
||||||
self.storage.save(stored_records)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
r"""Clears all chat messages from the memory."""
|
|
||||||
self.storage.clear()
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from camel.embeddings import BaseEmbedding, OpenAIEmbedding
|
|
||||||
from camel.memories.base import MemoryBlock
|
|
||||||
from camel.memories.records import ContextRecord, MemoryRecord
|
|
||||||
from camel.storages.vectordb_storages import (
|
|
||||||
BaseVectorStorage,
|
|
||||||
QdrantStorage,
|
|
||||||
VectorDBQuery,
|
|
||||||
VectorRecord,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorDBBlock(MemoryBlock):
|
|
||||||
r"""An implementation of the :obj:`MemoryBlock` abstract base class for
|
|
||||||
maintaining and retrieving information using vector embeddings within a
|
|
||||||
vector database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage (Optional[BaseVectorStorage], optional): The storage mechanism
|
|
||||||
for the vector database. Defaults to in-memory :obj:`Qdrant` if not
|
|
||||||
provided. (default: :obj:`None`)
|
|
||||||
embedding (Optional[BaseEmbedding], optional): Embedding mechanism to
|
|
||||||
convert chat messages into vector representations. Defaults to
|
|
||||||
:obj:`OpenAiEmbedding` if not provided. (default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
storage: Optional[BaseVectorStorage] = None,
|
|
||||||
embedding: Optional[BaseEmbedding] = None,
|
|
||||||
) -> None:
|
|
||||||
self.embedding = embedding or OpenAIEmbedding()
|
|
||||||
self.vector_dim = self.embedding.get_output_dim()
|
|
||||||
self.storage = storage or QdrantStorage(vector_dim=self.vector_dim)
|
|
||||||
|
|
||||||
def retrieve(
|
|
||||||
self,
|
|
||||||
keyword: str,
|
|
||||||
limit: int = 3,
|
|
||||||
) -> List[ContextRecord]:
|
|
||||||
r"""Retrieves similar records from the vector database based on the
|
|
||||||
content of the keyword.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keyword (str): This string will be converted into a vector
|
|
||||||
representation to query the database.
|
|
||||||
limit (int, optional): The maximum number of similar messages to
|
|
||||||
retrieve. (default: :obj:`3`).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ContextRecord]: A list of memory records retrieved from the
|
|
||||||
vector database based on similarity to :obj:`current_state`.
|
|
||||||
"""
|
|
||||||
query_vector = self.embedding.embed(keyword)
|
|
||||||
results = self.storage.query(
|
|
||||||
VectorDBQuery(query_vector=query_vector, top_k=limit)
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ContextRecord(
|
|
||||||
memory_record=MemoryRecord.from_dict(result.record.payload),
|
|
||||||
score=result.similarity,
|
|
||||||
)
|
|
||||||
for result in results
|
|
||||||
if result.record.payload is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
def write_records(self, records: List[MemoryRecord]) -> None:
|
|
||||||
"""
|
|
||||||
Converts the provided chat messages into vector representations and
|
|
||||||
writes them to the vector database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[MemoryRecord]): Memory records to be added to the
|
|
||||||
memory.
|
|
||||||
"""
|
|
||||||
v_records = [
|
|
||||||
VectorRecord(
|
|
||||||
vector=self.embedding.embed(record.message.content),
|
|
||||||
payload=record.to_dict(),
|
|
||||||
id=str(record.uuid),
|
|
||||||
)
|
|
||||||
for record in records
|
|
||||||
]
|
|
||||||
self.storage.add(v_records)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
r"""Removes all records from the vector database memory."""
|
|
||||||
self.storage.clear()
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .score_based import ScoreBasedContextCreator
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'ScoreBasedContextCreator',
|
|
||||||
]
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from camel.memories.base import BaseContextCreator
|
|
||||||
from camel.memories.records import ContextRecord
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.utils import BaseTokenCounter
|
|
||||||
|
|
||||||
|
|
||||||
class _ContextUnit(BaseModel):
|
|
||||||
idx: int
|
|
||||||
record: ContextRecord
|
|
||||||
num_tokens: int
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreBasedContextCreator(BaseContextCreator):
|
|
||||||
r"""A default implementation of context creation strategy, which inherits
|
|
||||||
from :obj:`BaseContextCreator`.
|
|
||||||
|
|
||||||
This class provides a strategy to generate a conversational context from
|
|
||||||
a list of chat history records while ensuring the total token count of
|
|
||||||
the context does not exceed a specified limit. It prunes messages based
|
|
||||||
on their score if the total token count exceeds the limit.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_counter (BaseTokenCounter): An instance responsible for counting
|
|
||||||
tokens in a message.
|
|
||||||
token_limit (int): The maximum number of tokens allowed in the
|
|
||||||
generated context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, token_counter: BaseTokenCounter, token_limit: int
|
|
||||||
) -> None:
|
|
||||||
self._token_counter = token_counter
|
|
||||||
self._token_limit = token_limit
|
|
||||||
|
|
||||||
@property
|
|
||||||
def token_counter(self) -> BaseTokenCounter:
|
|
||||||
return self._token_counter
|
|
||||||
|
|
||||||
@property
|
|
||||||
def token_limit(self) -> int:
|
|
||||||
return self._token_limit
|
|
||||||
|
|
||||||
def create_context(
|
|
||||||
self,
|
|
||||||
records: List[ContextRecord],
|
|
||||||
) -> Tuple[List[OpenAIMessage], int]:
|
|
||||||
r"""Creates conversational context from chat history while respecting
|
|
||||||
token limits.
|
|
||||||
|
|
||||||
Constructs the context from provided records and ensures that the total
|
|
||||||
token count does not exceed the specified limit by pruning the least
|
|
||||||
score messages if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records (List[ContextRecord]): A list of message records from which
|
|
||||||
to generate the context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[OpenAIMessage], int]: A tuple containing the constructed
|
|
||||||
context in OpenAIMessage format and the total token count.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If it's impossible to create a valid context without
|
|
||||||
exceeding the token limit.
|
|
||||||
"""
|
|
||||||
# Create unique context units list
|
|
||||||
uuid_set = set()
|
|
||||||
context_units = []
|
|
||||||
for idx, record in enumerate(records):
|
|
||||||
if record.memory_record.uuid not in uuid_set:
|
|
||||||
uuid_set.add(record.memory_record.uuid)
|
|
||||||
context_units.append(
|
|
||||||
_ContextUnit(
|
|
||||||
idx=idx,
|
|
||||||
record=record,
|
|
||||||
num_tokens=self.token_counter.count_tokens_from_messages(
|
|
||||||
[record.memory_record.to_openai_message()]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: optimize the process, may give information back to memory
|
|
||||||
|
|
||||||
# If not exceed token limit, simply return
|
|
||||||
total_tokens = sum([unit.num_tokens for unit in context_units])
|
|
||||||
if total_tokens <= self.token_limit:
|
|
||||||
return self._create_output(context_units)
|
|
||||||
|
|
||||||
# Sort by score
|
|
||||||
context_units = sorted(
|
|
||||||
context_units, key=lambda unit: unit.record.score
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove the least score messages until total token number is smaller
|
|
||||||
# than token limit
|
|
||||||
truncate_idx = None
|
|
||||||
for i, unit in enumerate(context_units):
|
|
||||||
if unit.record.score == 1:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot create context: exceed token limit.", total_tokens
|
|
||||||
)
|
|
||||||
total_tokens -= unit.num_tokens
|
|
||||||
if total_tokens <= self.token_limit:
|
|
||||||
truncate_idx = i
|
|
||||||
break
|
|
||||||
if truncate_idx is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot create context: exceed token limit.", total_tokens
|
|
||||||
)
|
|
||||||
return self._create_output(context_units[truncate_idx + 1 :])
|
|
||||||
|
|
||||||
def _create_output(
|
|
||||||
self, context_units: List[_ContextUnit]
|
|
||||||
) -> Tuple[List[OpenAIMessage], int]:
|
|
||||||
r"""Helper method to generate output from context units.
|
|
||||||
|
|
||||||
This method converts the provided context units into a format suitable
|
|
||||||
for output, specifically a list of OpenAIMessages and an integer
|
|
||||||
representing the total token count.
|
|
||||||
"""
|
|
||||||
context_units = sorted(context_units, key=lambda unit: unit.idx)
|
|
||||||
return [
|
|
||||||
unit.record.memory_record.to_openai_message()
|
|
||||||
for unit in context_units
|
|
||||||
], sum([unit.num_tokens for unit in context_units])
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import Any, ClassVar, Dict
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
|
||||||
|
|
||||||
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
|
|
||||||
from camel.types import OpenAIBackendRole
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryRecord(BaseModel):
|
|
||||||
r"""The basic message storing unit in the CAMEL memory system.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
message (BaseMessage): The main content of the record.
|
|
||||||
role_at_backend (OpenAIBackendRole): An enumeration value representing
|
|
||||||
the role this message played at the OpenAI backend. Note that this
|
|
||||||
value is different from the :obj:`RoleType` used in the CAMEL role
|
|
||||||
playing system.
|
|
||||||
uuid (UUID, optional): A universally unique identifier for this record.
|
|
||||||
This is used to uniquely identify this record in the memory system.
|
|
||||||
If not given, it will be assigned with a random UUID.
|
|
||||||
extra_info (Dict[str, str], optional): A dictionary of additional
|
|
||||||
key-value pairs that provide more information. If not given, it
|
|
||||||
will be an empty `Dict`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
message: BaseMessage
|
|
||||||
role_at_backend: OpenAIBackendRole
|
|
||||||
uuid: UUID = Field(default_factory=uuid4)
|
|
||||||
extra_info: Dict[str, str] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
_MESSAGE_TYPES: ClassVar[dict] = {
|
|
||||||
"BaseMessage": BaseMessage,
|
|
||||||
"FunctionCallingMessage": FunctionCallingMessage,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
|
|
||||||
r"""Reconstruct a :obj:`MemoryRecord` from the input dict.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record_dict(Dict[str, Any]): A dict generated by :meth:`to_dict`.
|
|
||||||
"""
|
|
||||||
message_cls = cls._MESSAGE_TYPES[record_dict["message"]["__class__"]]
|
|
||||||
kwargs: Dict = record_dict["message"].copy()
|
|
||||||
kwargs.pop("__class__")
|
|
||||||
reconstructed_message = message_cls(**kwargs)
|
|
||||||
return cls(
|
|
||||||
uuid=UUID(record_dict["uuid"]),
|
|
||||||
message=reconstructed_message,
|
|
||||||
role_at_backend=record_dict["role_at_backend"],
|
|
||||||
extra_info=record_dict["extra_info"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
r"""Convert the :obj:`MemoryRecord` to a dict for serialization
|
|
||||||
purposes.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"uuid": str(self.uuid),
|
|
||||||
"message": {
|
|
||||||
"__class__": self.message.__class__.__name__,
|
|
||||||
**asdict(self.message),
|
|
||||||
},
|
|
||||||
"role_at_backend": self.role_at_backend,
|
|
||||||
"extra_info": self.extra_info,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_openai_message(self) -> OpenAIMessage:
|
|
||||||
r"""Converts the record to an :obj:`OpenAIMessage` object."""
|
|
||||||
return self.message.to_openai_message(self.role_at_backend)
|
|
||||||
|
|
||||||
|
|
||||||
class ContextRecord(BaseModel):
|
|
||||||
r"""The result of memory retrieving."""
|
|
||||||
|
|
||||||
memory_record: MemoryRecord
|
|
||||||
score: float
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from camel.types import (
|
|
||||||
ChatCompletionAssistantMessageParam,
|
|
||||||
ChatCompletionMessageParam,
|
|
||||||
ChatCompletionSystemMessageParam,
|
|
||||||
ChatCompletionToolMessageParam,
|
|
||||||
ChatCompletionUserMessageParam,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .conversion import (
|
|
||||||
AlpacaItem,
|
|
||||||
HermesFunctionFormatter,
|
|
||||||
ShareGPTMessage,
|
|
||||||
)
|
|
||||||
from .conversion.conversation_models import (
|
|
||||||
ShareGPTConversation,
|
|
||||||
)
|
|
||||||
from .conversion.sharegpt.function_call_formatter import (
|
|
||||||
FunctionCallFormatter,
|
|
||||||
)
|
|
||||||
|
|
||||||
OpenAISystemMessage = ChatCompletionSystemMessageParam
|
|
||||||
OpenAIAssistantMessage = Union[
|
|
||||||
ChatCompletionAssistantMessageParam,
|
|
||||||
ChatCompletionToolMessageParam,
|
|
||||||
]
|
|
||||||
OpenAIUserMessage = ChatCompletionUserMessageParam
|
|
||||||
OpenAIToolMessageParam = ChatCompletionToolMessageParam
|
|
||||||
|
|
||||||
OpenAIMessage = ChatCompletionMessageParam
|
|
||||||
|
|
||||||
|
|
||||||
from .base import BaseMessage # noqa: E402
|
|
||||||
from .func_message import FunctionCallingMessage # noqa: E402
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'OpenAISystemMessage',
|
|
||||||
'OpenAIAssistantMessage',
|
|
||||||
'OpenAIUserMessage',
|
|
||||||
'OpenAIToolMessageParam',
|
|
||||||
'OpenAIMessage',
|
|
||||||
'FunctionCallFormatter',
|
|
||||||
'HermesFunctionFormatter',
|
|
||||||
'ShareGPTConversation',
|
|
||||||
'ShareGPTMessage',
|
|
||||||
'BaseMessage',
|
|
||||||
'FunctionCallingMessage',
|
|
||||||
'AlpacaItem',
|
|
||||||
]
|
|
||||||
@@ -1,541 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from camel.messages import (
|
|
||||||
FunctionCallFormatter,
|
|
||||||
HermesFunctionFormatter,
|
|
||||||
OpenAIAssistantMessage,
|
|
||||||
OpenAIMessage,
|
|
||||||
OpenAISystemMessage,
|
|
||||||
OpenAIUserMessage,
|
|
||||||
)
|
|
||||||
from camel.messages.conversion import ShareGPTMessage
|
|
||||||
from camel.prompts import CodePrompt, TextPrompt
|
|
||||||
from camel.types import (
|
|
||||||
OpenAIBackendRole,
|
|
||||||
OpenAIImageType,
|
|
||||||
OpenAIVisionDetailType,
|
|
||||||
RoleType,
|
|
||||||
)
|
|
||||||
from camel.utils import Constants
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseMessage:
|
|
||||||
r"""Base class for message objects used in CAMEL chat system.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
role_name (str): The name of the user or assistant role.
|
|
||||||
role_type (RoleType): The type of role, either :obj:`RoleType.
|
|
||||||
ASSISTANT` or :obj:`RoleType.USER`.
|
|
||||||
meta_dict (Optional[Dict[str, str]]): Additional metadata dictionary
|
|
||||||
for the message.
|
|
||||||
content (str): The content of the message.
|
|
||||||
video_bytes (Optional[bytes]): Optional bytes of a video associated
|
|
||||||
with the message. (default: :obj:`None`)
|
|
||||||
image_list (Optional[List[Image.Image]]): Optional list of PIL Image
|
|
||||||
objects associated with the message. (default: :obj:`None`)
|
|
||||||
image_detail (Literal["auto", "low", "high"]): Detail level of the
|
|
||||||
images associated with the message. (default: :obj:`auto`)
|
|
||||||
video_detail (Literal["auto", "low", "high"]): Detail level of the
|
|
||||||
videos associated with the message. (default: :obj:`low`)
|
|
||||||
parsed: Optional[Union[Type[BaseModel], dict]]: Optional object which
|
|
||||||
is parsed from the content. (default: :obj:`None`)
|
|
||||||
"""
|
|
||||||
|
|
||||||
role_name: str
|
|
||||||
role_type: RoleType
|
|
||||||
meta_dict: Optional[Dict[str, Any]]
|
|
||||||
content: str
|
|
||||||
|
|
||||||
video_bytes: Optional[bytes] = None
|
|
||||||
image_list: Optional[List[Image.Image]] = None
|
|
||||||
image_detail: Literal["auto", "low", "high"] = "auto"
|
|
||||||
video_detail: Literal["auto", "low", "high"] = "low"
|
|
||||||
parsed: Optional[Union[Type[BaseModel], dict]] = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_user_message(
|
|
||||||
cls,
|
|
||||||
role_name: str,
|
|
||||||
content: str,
|
|
||||||
meta_dict: Optional[Dict[str, str]] = None,
|
|
||||||
video_bytes: Optional[bytes] = None,
|
|
||||||
image_list: Optional[List[Image.Image]] = None,
|
|
||||||
image_detail: Union[
|
|
||||||
OpenAIVisionDetailType, str
|
|
||||||
] = OpenAIVisionDetailType.AUTO,
|
|
||||||
video_detail: Union[
|
|
||||||
OpenAIVisionDetailType, str
|
|
||||||
] = OpenAIVisionDetailType.LOW,
|
|
||||||
) -> "BaseMessage":
|
|
||||||
r"""Create a new user message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
role_name (str): The name of the user role.
|
|
||||||
content (str): The content of the message.
|
|
||||||
meta_dict (Optional[Dict[str, str]]): Additional metadata
|
|
||||||
dictionary for the message.
|
|
||||||
video_bytes (Optional[bytes]): Optional bytes of a video
|
|
||||||
associated with the message.
|
|
||||||
image_list (Optional[List[Image.Image]]): Optional list of PIL
|
|
||||||
Image objects associated with the message.
|
|
||||||
image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
|
|
||||||
the images associated with the message.
|
|
||||||
video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
|
|
||||||
the videos associated with the message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessage: The new user message.
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
role_name,
|
|
||||||
RoleType.USER,
|
|
||||||
meta_dict,
|
|
||||||
content,
|
|
||||||
video_bytes,
|
|
||||||
image_list,
|
|
||||||
OpenAIVisionDetailType(image_detail).value,
|
|
||||||
OpenAIVisionDetailType(video_detail).value,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_assistant_message(
|
|
||||||
cls,
|
|
||||||
role_name: str,
|
|
||||||
content: str,
|
|
||||||
meta_dict: Optional[Dict[str, str]] = None,
|
|
||||||
video_bytes: Optional[bytes] = None,
|
|
||||||
image_list: Optional[List[Image.Image]] = None,
|
|
||||||
image_detail: Union[
|
|
||||||
OpenAIVisionDetailType, str
|
|
||||||
] = OpenAIVisionDetailType.AUTO,
|
|
||||||
video_detail: Union[
|
|
||||||
OpenAIVisionDetailType, str
|
|
||||||
] = OpenAIVisionDetailType.LOW,
|
|
||||||
) -> "BaseMessage":
|
|
||||||
r"""Create a new assistant message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
role_name (str): The name of the assistant role.
|
|
||||||
content (str): The content of the message.
|
|
||||||
meta_dict (Optional[Dict[str, str]]): Additional metadata
|
|
||||||
dictionary for the message.
|
|
||||||
video_bytes (Optional[bytes]): Optional bytes of a video
|
|
||||||
associated with the message.
|
|
||||||
image_list (Optional[List[Image.Image]]): Optional list of PIL
|
|
||||||
Image objects associated with the message.
|
|
||||||
image_detail (Union[OpenAIVisionDetailType, str]): Detail level of
|
|
||||||
the images associated with the message.
|
|
||||||
video_detail (Union[OpenAIVisionDetailType, str]): Detail level of
|
|
||||||
the videos associated with the message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessage: The new assistant message.
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
role_name,
|
|
||||||
RoleType.ASSISTANT,
|
|
||||||
meta_dict,
|
|
||||||
content,
|
|
||||||
video_bytes,
|
|
||||||
image_list,
|
|
||||||
OpenAIVisionDetailType(image_detail).value,
|
|
||||||
OpenAIVisionDetailType(video_detail).value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_new_instance(self, content: str) -> "BaseMessage":
|
|
||||||
r"""Create a new instance of the :obj:`BaseMessage` with updated
|
|
||||||
content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content (str): The new content value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessage: The new instance of :obj:`BaseMessage`.
|
|
||||||
"""
|
|
||||||
return self.__class__(
|
|
||||||
role_name=self.role_name,
|
|
||||||
role_type=self.role_type,
|
|
||||||
meta_dict=self.meta_dict,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __add__(self, other: Any) -> Union["BaseMessage", Any]:
|
|
||||||
r"""Addition operator override for :obj:`BaseMessage`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other (Any): The value to be added with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[BaseMessage, Any]: The result of the addition.
|
|
||||||
"""
|
|
||||||
if isinstance(other, BaseMessage):
|
|
||||||
combined_content = self.content.__add__(other.content)
|
|
||||||
elif isinstance(other, str):
|
|
||||||
combined_content = self.content.__add__(other)
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
f"Unsupported operand type(s) for +: '{type(self)}' and "
|
|
||||||
f"'{type(other)}'"
|
|
||||||
)
|
|
||||||
return self.create_new_instance(combined_content)
|
|
||||||
|
|
||||||
def __mul__(self, other: Any) -> Union["BaseMessage", Any]:
|
|
||||||
r"""Multiplication operator override for :obj:`BaseMessage`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other (Any): The value to be multiplied with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[BaseMessage, Any]: The result of the multiplication.
|
|
||||||
"""
|
|
||||||
if isinstance(other, int):
|
|
||||||
multiplied_content = self.content.__mul__(other)
|
|
||||||
return self.create_new_instance(multiplied_content)
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
f"Unsupported operand type(s) for *: '{type(self)}' and "
|
|
||||||
f"'{type(other)}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
r"""Length operator override for :obj:`BaseMessage`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The length of the content.
|
|
||||||
"""
|
|
||||||
return len(self.content)
|
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
|
||||||
r"""Contains operator override for :obj:`BaseMessage`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item (str): The item to check for containment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: :obj:`True` if the item is contained in the content,
|
|
||||||
:obj:`False` otherwise.
|
|
||||||
"""
|
|
||||||
return item in self.content
|
|
||||||
|
|
||||||
def extract_text_and_code_prompts(
|
|
||||||
self,
|
|
||||||
) -> Tuple[List[TextPrompt], List[CodePrompt]]:
|
|
||||||
r"""Extract text and code prompts from the message content.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[TextPrompt], List[CodePrompt]]: A tuple containing a
|
|
||||||
list of text prompts and a list of code prompts extracted
|
|
||||||
from the content.
|
|
||||||
"""
|
|
||||||
text_prompts: List[TextPrompt] = []
|
|
||||||
code_prompts: List[CodePrompt] = []
|
|
||||||
|
|
||||||
lines = self.content.split("\n")
|
|
||||||
idx = 0
|
|
||||||
start_idx = 0
|
|
||||||
while idx < len(lines):
|
|
||||||
while idx < len(lines) and (
|
|
||||||
not lines[idx].lstrip().startswith("```")
|
|
||||||
):
|
|
||||||
idx += 1
|
|
||||||
text = "\n".join(lines[start_idx:idx]).strip()
|
|
||||||
text_prompts.append(TextPrompt(text))
|
|
||||||
|
|
||||||
if idx >= len(lines):
|
|
||||||
break
|
|
||||||
|
|
||||||
code_type = lines[idx].strip()[3:].strip()
|
|
||||||
idx += 1
|
|
||||||
start_idx = idx
|
|
||||||
while not lines[idx].lstrip().startswith("```"):
|
|
||||||
idx += 1
|
|
||||||
code = "\n".join(lines[start_idx:idx]).strip()
|
|
||||||
code_prompts.append(CodePrompt(code, code_type=code_type))
|
|
||||||
|
|
||||||
idx += 1
|
|
||||||
start_idx = idx
|
|
||||||
|
|
||||||
return text_prompts, code_prompts
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_sharegpt(
|
|
||||||
cls,
|
|
||||||
message: ShareGPTMessage,
|
|
||||||
function_format: Optional[FunctionCallFormatter[Any, Any]] = None,
|
|
||||||
role_mapping=None,
|
|
||||||
) -> "BaseMessage":
|
|
||||||
r"""Convert ShareGPT message to BaseMessage or FunctionCallingMessage.
|
|
||||||
Note tool calls and responses have an 'assistant' role in CAMEL
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message (ShareGPTMessage): ShareGPT message to convert.
|
|
||||||
function_format (FunctionCallFormatter, optional): Function call
|
|
||||||
formatter to use. (default: :obj:`HermesFunctionFormatter()`.
|
|
||||||
role_mapping (Dict[str, List[str, RoleType]], optional): Role
|
|
||||||
mapping to use. Defaults to a CAMEL specific mapping.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessage: Converted message.
|
|
||||||
"""
|
|
||||||
from camel.messages import FunctionCallingMessage
|
|
||||||
|
|
||||||
if role_mapping is None:
|
|
||||||
role_mapping = {
|
|
||||||
"system": ["system", RoleType.USER],
|
|
||||||
"human": ["user", RoleType.USER],
|
|
||||||
"gpt": ["assistant", RoleType.ASSISTANT],
|
|
||||||
"tool": ["assistant", RoleType.ASSISTANT],
|
|
||||||
}
|
|
||||||
role_name, role_type = role_mapping[message.from_]
|
|
||||||
|
|
||||||
if function_format is None:
|
|
||||||
function_format = HermesFunctionFormatter()
|
|
||||||
|
|
||||||
# Check if this is a function-related message
|
|
||||||
if message.from_ == "gpt":
|
|
||||||
func_info = function_format.extract_tool_calls(message.value)
|
|
||||||
if (
|
|
||||||
func_info and len(func_info) == 1
|
|
||||||
): # TODO: Handle multiple tool calls
|
|
||||||
# Including cleaned content is useful to
|
|
||||||
# remind consumers of non-considered content
|
|
||||||
clean_content = re.sub(
|
|
||||||
r"<tool_call>.*?</tool_call>",
|
|
||||||
"",
|
|
||||||
message.value,
|
|
||||||
flags=re.DOTALL,
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
return FunctionCallingMessage(
|
|
||||||
role_name=role_name,
|
|
||||||
role_type=role_type,
|
|
||||||
meta_dict=None,
|
|
||||||
content=clean_content,
|
|
||||||
func_name=func_info[0].__dict__["name"],
|
|
||||||
args=func_info[0].__dict__["arguments"],
|
|
||||||
)
|
|
||||||
elif message.from_ == "tool":
|
|
||||||
func_r_info = function_format.extract_tool_response(message.value)
|
|
||||||
if func_r_info:
|
|
||||||
return FunctionCallingMessage(
|
|
||||||
role_name=role_name,
|
|
||||||
role_type=role_type,
|
|
||||||
meta_dict=None,
|
|
||||||
content="",
|
|
||||||
func_name=func_r_info.__dict__["name"],
|
|
||||||
result=func_r_info.__dict__["content"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Regular message
|
|
||||||
return cls(
|
|
||||||
role_name=role_name,
|
|
||||||
role_type=role_type,
|
|
||||||
meta_dict=None,
|
|
||||||
content=message.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_sharegpt(
|
|
||||||
self,
|
|
||||||
function_format: Optional[FunctionCallFormatter] = None,
|
|
||||||
) -> ShareGPTMessage:
|
|
||||||
r"""Convert BaseMessage to ShareGPT message
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_format (FunctionCallFormatter): Function call formatter
|
|
||||||
to use. Defaults to Hermes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if function_format is None:
|
|
||||||
function_format = HermesFunctionFormatter()
|
|
||||||
|
|
||||||
# Convert role type to ShareGPT 'from' field
|
|
||||||
if self.role_type == RoleType.USER:
|
|
||||||
from_ = "system" if self.role_name == "system" else "human"
|
|
||||||
else: # RoleType.ASSISTANT
|
|
||||||
from_ = "gpt"
|
|
||||||
|
|
||||||
# Function conversion code in FunctionCallingMessage
|
|
||||||
return ShareGPTMessage(from_=from_, value=self.content) # type: ignore[call-arg]
|
|
||||||
|
|
||||||
def to_openai_message(
|
|
||||||
self,
|
|
||||||
role_at_backend: OpenAIBackendRole,
|
|
||||||
) -> OpenAIMessage:
|
|
||||||
r"""Converts the message to an :obj:`OpenAIMessage` object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
role_at_backend (OpenAIBackendRole): The role of the message in
|
|
||||||
OpenAI chat system.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpenAIMessage: The converted :obj:`OpenAIMessage` object.
|
|
||||||
"""
|
|
||||||
if role_at_backend == OpenAIBackendRole.SYSTEM:
|
|
||||||
return self.to_openai_system_message()
|
|
||||||
elif role_at_backend == OpenAIBackendRole.USER:
|
|
||||||
return self.to_openai_user_message()
|
|
||||||
elif role_at_backend == OpenAIBackendRole.ASSISTANT:
|
|
||||||
return self.to_openai_assistant_message()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported role: {role_at_backend}.")
|
|
||||||
|
|
||||||
def to_openai_system_message(self) -> OpenAISystemMessage:
|
|
||||||
r"""Converts the message to an :obj:`OpenAISystemMessage` object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpenAISystemMessage: The converted :obj:`OpenAISystemMessage`
|
|
||||||
object.
|
|
||||||
"""
|
|
||||||
return {"role": "system", "content": self.content}
|
|
||||||
|
|
||||||
def to_openai_user_message(self) -> OpenAIUserMessage:
|
|
||||||
r"""Converts the message to an :obj:`OpenAIUserMessage` object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpenAIUserMessage: The converted :obj:`OpenAIUserMessage` object.
|
|
||||||
"""
|
|
||||||
hybird_content: List[Any] = []
|
|
||||||
hybird_content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": self.content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if self.image_list and len(self.image_list) > 0:
|
|
||||||
for image in self.image_list:
|
|
||||||
if image.format is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Image's `format` is `None`, please "
|
|
||||||
f"transform the `PIL.Image.Image` to one of "
|
|
||||||
f"following supported formats, such as "
|
|
||||||
f"{list(OpenAIImageType)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
image_type: str = image.format.lower()
|
|
||||||
if image_type not in OpenAIImageType:
|
|
||||||
raise ValueError(
|
|
||||||
f"Image type {image.format} "
|
|
||||||
f"is not supported by OpenAI vision model"
|
|
||||||
)
|
|
||||||
with io.BytesIO() as buffer:
|
|
||||||
image.save(fp=buffer, format=image.format)
|
|
||||||
encoded_image = base64.b64encode(buffer.getvalue()).decode(
|
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
image_prefix = f"data:image/{image_type};base64,"
|
|
||||||
hybird_content.append(
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"{image_prefix}{encoded_image}",
|
|
||||||
"detail": self.image_detail,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.video_bytes:
|
|
||||||
import imageio.v3 as iio
|
|
||||||
|
|
||||||
base64Frames: List[str] = []
|
|
||||||
frame_count = 0
|
|
||||||
# read video bytes
|
|
||||||
video = iio.imiter(
|
|
||||||
self.video_bytes, plugin=Constants.VIDEO_DEFAULT_PLUG_PYAV
|
|
||||||
)
|
|
||||||
|
|
||||||
for frame in video:
|
|
||||||
frame_count += 1
|
|
||||||
if (
|
|
||||||
frame_count % Constants.VIDEO_IMAGE_EXTRACTION_INTERVAL
|
|
||||||
== 0
|
|
||||||
):
|
|
||||||
# convert frame to numpy array
|
|
||||||
frame_array = np.asarray(frame)
|
|
||||||
frame_image = Image.fromarray(frame_array)
|
|
||||||
|
|
||||||
# Get the dimensions of the frame
|
|
||||||
width, height = frame_image.size
|
|
||||||
|
|
||||||
# resize the frame to the default image size
|
|
||||||
new_width = Constants.VIDEO_DEFAULT_IMAGE_SIZE
|
|
||||||
aspect_ratio = width / height
|
|
||||||
new_height = int(new_width / aspect_ratio)
|
|
||||||
resized_img = frame_image.resize((new_width, new_height))
|
|
||||||
|
|
||||||
# encode the image to base64
|
|
||||||
with io.BytesIO() as buffer:
|
|
||||||
image_format = OpenAIImageType.JPEG.value
|
|
||||||
image_format = image_format.upper()
|
|
||||||
resized_img.save(fp=buffer, format=image_format)
|
|
||||||
encoded_image = base64.b64encode(
|
|
||||||
buffer.getvalue()
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
base64Frames.append(encoded_image)
|
|
||||||
|
|
||||||
for encoded_image in base64Frames:
|
|
||||||
item = {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
|
||||||
"detail": self.video_detail,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
hybird_content.append(item)
|
|
||||||
|
|
||||||
if len(hybird_content) > 1:
|
|
||||||
return {
|
|
||||||
"role": "user",
|
|
||||||
"content": hybird_content,
|
|
||||||
}
|
|
||||||
# This return just for str message
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"role": "user",
|
|
||||||
"content": self.content,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
|
|
||||||
r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpenAIAssistantMessage: The converted :obj:`OpenAIAssistantMessage`
|
|
||||||
object.
|
|
||||||
"""
|
|
||||||
return {"role": "assistant", "content": self.content}
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
r"""Converts the message to a dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The converted dictionary.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"role_name": self.role_name,
|
|
||||||
"role_type": self.role_type.name,
|
|
||||||
**(self.meta_dict or {}),
|
|
||||||
"content": self.content,
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .alpaca import AlpacaItem
|
|
||||||
from .conversation_models import (
|
|
||||||
ShareGPTConversation,
|
|
||||||
ShareGPTMessage,
|
|
||||||
ToolCall,
|
|
||||||
ToolResponse,
|
|
||||||
)
|
|
||||||
from .sharegpt import HermesFunctionFormatter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'ShareGPTMessage',
|
|
||||||
'ShareGPTConversation',
|
|
||||||
'HermesFunctionFormatter',
|
|
||||||
'AlpacaItem',
|
|
||||||
'ToolCall',
|
|
||||||
'ToolResponse',
|
|
||||||
]
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class AlpacaItem(BaseModel):
    r"""Represents an instruction-response item in the Alpaca format.

    Appropriate for both cases where input field is empty, or populated.
    Provides parsing from string format using the class method from_string().

    Args:
        instruction (str): The instruction/question/prompt
        input (str): Input context or examples (put empty string if none)
        output (str): The response/answer to the instruction
    """

    instruction: str = Field(description="The instruction/question/prompt")
    input: str = Field(
        description="Optional context or input for the task."
        " For example, when the instruction is \"Summarize the "
        "following article\", the input is the article."
    )
    output: str = Field(description="The response/answer to the instruction")

    @field_validator('instruction', 'output')
    def no_section_markers(cls, value: str) -> str:
        r"""Ensures fields don't contain section markers like '###
        Response:'
        """
        # Section markers inside a field would break round-tripping
        # through from_string()/to_string().
        if (
            '### Response' in value
            or '### Instruction' in value
            or '### Input' in value
        ):
            raise ValueError("Field cannot contain section markers")
        return value.strip()

    @classmethod
    def from_string(cls, text: str) -> "AlpacaItem":
        r"""Creates an AlpacaItem from a formatted string.

        Args:
            text: String in either of these formats:
                With input:
                ### Instruction:
                {instruction}
                ### Input:
                {input}
                ### Response:
                {response}

                Without input:
                ### Instruction:
                {instruction}
                ### Response:
                {response}

        Returns:
            AlpacaItem: Parsed instance

        Raises:
            ValueError: text doesn't match expected format or sections missing
        """
        # Strip and standardize newlines
        text = text.strip().replace('\r\n', '\n')

        # Try to extract sections using regex; each section runs until the
        # next "###" header or end of string.
        instruction_match = re.search(
            r'###\s*Instruction:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        input_match = re.search(
            r'###\s*Input:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )
        response_match = re.search(
            r'###\s*Response:\s*\n(.+?)(?=\n###|\Z)', text, re.DOTALL
        )

        if not instruction_match or not response_match:
            raise ValueError(
                "Text must contain '### Instruction:'"
                " and '### Response:' sections"
            )

        return cls(
            instruction=instruction_match.group(1).strip(),
            input=input_match.group(1).strip() if input_match else "",
            output=response_match.group(1).strip(),
        )

    def to_string(self) -> str:
        r"""Converts the AlpacaItem to its string representation.

        Returns:
            str: Formatted string representation with sections markers
        """
        return "\n".join(
            [
                "### Instruction:",
                self.instruction,
                "",
                "### Input:",
                self.input,
                "",
                "### Response:",
                self.output,
            ]
        )
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Dict, List, Literal
|
|
||||||
|
|
||||||
from pydantic import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
RootModel,
|
|
||||||
field_validator,
|
|
||||||
model_validator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTMessage(BaseModel):
    r"""A single message in ShareGPT format with enhanced validation"""

    # ShareGPT uses the JSON key "from" for the sender role; ``from`` is a
    # Python keyword, so the field is named ``from_`` and aliased.
    from_: Literal["human", "gpt", "system", "tool"] = Field(
        alias="from", description="The role of the message sender"
    )
    # Message payload; length-bounded to guard against pathological inputs.
    value: str = Field(
        min_length=0,
        max_length=100000,
        description="The content of the message",
    )

    # ``populate_by_name`` lets callers construct with either ``from`` or
    # ``from_``; ``extra="forbid"`` rejects unknown keys.
    model_config = {
        "populate_by_name": True,
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {"from": "human", "value": "What's the weather like today?"}
            ]
        },
    }
|
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTConversation(RootModel):
    r"""A full conversation in ShareGPT format with validation"""

    root: List[ShareGPTMessage]

    @model_validator(mode='after')
    def validate_conversation_flow(self) -> 'ShareGPTConversation':
        r"""Validate the conversation follows logical message order.

        Returns:
            ShareGPTConversation: The validated conversation (self).

        Raises:
            ValueError: If the conversation is empty, starts with an
                invalid role, or contains an out-of-order message.
        """
        messages = self.root

        if not messages:
            raise ValueError("Conversation cannot be empty")

        if messages[0].from_ not in ("system", "human"):
            raise ValueError(
                "Conversation must start with either system or human message"
            )

        # Validate message sequence
        for i in range(1, len(messages)):
            curr, prev = messages[i], messages[i - 1]

            # A tool result must answer a tool call emitted by the model.
            if curr.from_ == "tool":
                if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
                    raise ValueError(
                        f"Tool response at position {i} "
                        f"must follow a gpt message with a tool call"
                    )

            # The model only speaks after the user or after a tool result.
            if curr.from_ == "gpt" and prev.from_ not in (
                "human",
                "tool",
            ):
                raise ValueError(
                    f"Assistant message at position {i} "
                    f"must follow a human or tool message"
                )

        return self

    def model_dump(self, **kwargs):
        # NOTE: returns the raw message-model list rather than a serialized
        # dict/list of dicts; kept as-is for backward compatibility.
        return self.root

    def __iter__(self):
        return iter(self.root)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
    r"""Represents a single tool/function call with validation"""

    # Name of the tool/function being invoked; length-bounded.
    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool to call",
    )
    arguments: Dict[str, Any] = Field(
        description="The arguments to pass to the tool"
    )

    @field_validator('arguments')
    @classmethod
    def validate_arguments(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        r"""Validate argument structure and content.

        Args:
            v (Dict[str, Any]): The raw arguments mapping.

        Returns:
            Dict[str, Any]: The validated arguments, unchanged.

        Raises:
            ValueError: If the arguments cannot be serialized to JSON.
        """

        # Try to serialize arguments to ensure they're JSON-compatible
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Arguments must be JSON-serializable")

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "arguments": {"city": "London", "units": "celsius"},
                }
            ]
        },
    }
|
|
||||||
|
|
||||||
|
|
||||||
class ToolResponse(BaseModel):
    r"""Represents a tool/function response with validation. This is a
    base class and default implementation for tool responses, for the purpose
    of converting between different formats.
    """

    # Name of the tool that produced this response; length-bounded.
    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool that was called",
    )
    content: Any = Field(
        description="The response content from the tool."
        " Must be JSON serializable literal or object"
    )

    @field_validator('content')
    @classmethod
    def validate_content(cls, v: Any) -> Any:
        r"""Validate response content structure.

        Args:
            v (Any): The raw response content (annotation matches the
                ``content: Any`` field).

        Returns:
            Any: The validated content, unchanged.

        Raises:
            ValueError: If the content cannot be serialized to JSON.
        """

        # Ensure content is JSON-serializable
        try:
            json.dumps(v)
        except (TypeError, ValueError):
            raise ValueError("Response content must be JSON-serializable")

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "content": {
                        "temperature": 20,
                        "conditions": "sunny",
                        "humidity": 65,
                    },
                }
            ]
        },
    }
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
|
|
||||||
from .hermes import HermesFunctionFormatter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'HermesFunctionFormatter',
|
|
||||||
]
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
|
||||||
|
|
||||||
from camel.messages.conversion import (
|
|
||||||
ToolCall,
|
|
||||||
ToolResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Covariant type variables so a formatter declared over subclasses of
# ToolCall/ToolResponse is usable where the base types are expected.
CallT = TypeVar('CallT', bound=ToolCall, covariant=True)
ResponseT = TypeVar('ResponseT', bound=ToolResponse, covariant=True)


class FunctionCallFormatter(ABC, Generic[CallT, ResponseT]):
    r"""Abstract base class for function calling formats.

    Concrete subclasses define how tool calls and tool responses are
    embedded into, and extracted from, plain message strings.
    """

    @abstractmethod
    def extract_tool_calls(self, message: str) -> List[CallT]:
        r"""Extract function call info from a message string"""
        pass

    @abstractmethod
    def extract_tool_response(self, message: str) -> Optional[ResponseT]:
        r"""Extract function response info from a message string"""
        pass

    @abstractmethod
    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Format a function call into a message string"""
        pass

    @abstractmethod
    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Format a function response into a message string"""
        pass
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
from .hermes_function_formatter import HermesFunctionFormatter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'HermesFunctionFormatter',
|
|
||||||
]
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from camel.messages.conversion import (
|
|
||||||
ToolCall,
|
|
||||||
ToolResponse,
|
|
||||||
)
|
|
||||||
from camel.messages.conversion.sharegpt.function_call_formatter import (
|
|
||||||
FunctionCallFormatter,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HermesToolResponse(ToolResponse):
    r"""Represents a single tool/function response with validation.

    Inherits all fields and validation from :obj:`ToolResponse`.
    """

    pass
|
|
||||||
|
|
||||||
|
|
||||||
class HermesToolCall(ToolCall):
    r"""Represents a single tool/function call with validation.

    Inherits all fields and validation from :obj:`ToolCall`.
    """

    pass
|
|
||||||
|
|
||||||
|
|
||||||
class HermesFunctionFormatter(
    FunctionCallFormatter[HermesToolCall, HermesToolResponse]
):
    r"""Hermes-style function calling format implementation with validation"""

    @staticmethod
    def _loads_lenient(payload: str) -> Any:
        r"""Parse a JSON payload, tolerating legacy ``str(dict)`` output.

        Tries strict JSON first; on failure, retries after replacing single
        quotes with double quotes (the historical payload form produced by
        ``str(dict)``). Note the fallback can corrupt values that
        legitimately contain single quotes, which is why strict JSON is
        preferred.

        Args:
            payload (str): The raw payload text between the markers.

        Returns:
            Any: The parsed JSON value.

        Raises:
            json.JSONDecodeError: If neither strict nor lenient parsing
                succeeds.
        """
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return json.loads(payload.replace("'", '"'))

    def extract_tool_calls(self, message: str) -> List[HermesToolCall]:
        r"""Extracts all tool calls from the provided message string.

        Args:
            message (str): The input message string containing potential tool
                calls.

        Returns:
            List[HermesToolCall]: A list of parsed HermesToolCall objects.
        """
        tool_calls = []
        pattern = r"<tool_call>\s*({.*?})\s*</tool_call>"
        matches = re.finditer(pattern, message, re.DOTALL)

        for match in matches:
            try:
                call_dict = self._loads_lenient(match.group(1))
                tool_calls.append(HermesToolCall.model_validate(call_dict))
            except Exception as e:
                # Best-effort parsing: skip malformed calls rather than fail.
                print(f"Warning: Failed to parse tool call: {e}")
                continue

        return tool_calls

    def extract_tool_response(
        self, message: str
    ) -> Optional[HermesToolResponse]:
        r"""Extracts a single tool response from the provided message string.

        Args:
            message (str): The input message string containing a potential
                tool response.

        Returns:
            Optional[HermesToolResponse]: A parsed HermesToolResponse object,
                or None if no valid response is found.
        """
        pattern = r"<tool_response>\s*({.*?})\s*</tool_response>"
        match = re.search(pattern, message, re.DOTALL)

        if match:
            try:
                response_dict = self._loads_lenient(match.group(1))
                return HermesToolResponse.model_validate(response_dict)
            except Exception as e:
                print(f"Warning: Failed to parse tool response: {e}")
                return None
        return None

    def format_tool_call(
        self, content: str, func_name: str, args: Dict[str, Any]
    ) -> str:
        r"""Formats a tool call message with the given content, function name,
        and arguments.

        Args:
            content (str): The content or message to be included in the tool
                call.
            func_name (str): The name of the function being called.
            args (Dict[str, Any]): A dictionary of arguments to be passed to
                the function.

        Returns:
            str: A formatted string representing the tool call in Hermes
                format.
        """
        tool_call_dict = {"name": func_name, "arguments": args}
        # Emit valid JSON (double quotes) so the payload round-trips through
        # the extractors without relying on quote substitution; fall back to
        # the legacy str(dict) form if args are not JSON-serializable.
        try:
            payload = json.dumps(tool_call_dict)
        except (TypeError, ValueError):
            payload = str(tool_call_dict)
        return f"{content}\n<tool_call>\n{payload}\n</tool_call>"

    def format_tool_response(self, func_name: str, result: Any) -> str:
        r"""Formats a tool response message with the given function name and
        result.

        Args:
            func_name (str): The name of the function whose result is being
                returned.
            result (Any): The result to be included in the tool response.

        Returns:
            str: A formatted string representing the tool response in Hermes
                format.
        """
        response_dict = {"name": func_name, "content": result}
        # Same JSON-first serialization as format_tool_call.
        try:
            payload = json.dumps(response_dict)
        except (TypeError, ValueError):
            payload = str(response_dict)
        return f"<tool_response>\n{payload}\n</tool_response>"
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from camel.messages import (
|
|
||||||
BaseMessage,
|
|
||||||
HermesFunctionFormatter,
|
|
||||||
OpenAIAssistantMessage,
|
|
||||||
OpenAIMessage,
|
|
||||||
OpenAIToolMessageParam,
|
|
||||||
)
|
|
||||||
from camel.messages.conversion import (
|
|
||||||
ShareGPTMessage,
|
|
||||||
ToolCall,
|
|
||||||
ToolResponse,
|
|
||||||
)
|
|
||||||
from camel.messages.conversion.sharegpt.function_call_formatter import (
|
|
||||||
FunctionCallFormatter,
|
|
||||||
)
|
|
||||||
from camel.types import OpenAIBackendRole
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class FunctionCallingMessage(BaseMessage):
    r"""Class for message objects used specifically for
    function-related messages.

    Args:
        func_name (Optional[str]): The name of the function used.
            (default: :obj:`None`)
        args (Optional[Dict]): The dictionary of arguments passed to the
            function. (default: :obj:`None`)
        result (Optional[Any]): The result of function execution.
            (default: :obj:`None`)
        tool_call_id (Optional[str]): The ID of the tool call, if available.
            (default: :obj:`None`)
    """

    # NOTE: dataclass field order defines the generated __init__ signature;
    # do not reorder these fields.
    func_name: Optional[str] = None
    args: Optional[Dict] = None
    result: Optional[Any] = None
    tool_call_id: Optional[str] = None

    def to_openai_message(
        self,
        role_at_backend: OpenAIBackendRole,
    ) -> OpenAIMessage:
        r"""Converts the message to an :obj:`OpenAIMessage` object.

        Args:
            role_at_backend (OpenAIBackendRole): The role of the message in
                OpenAI chat system.

        Returns:
            OpenAIMessage: The converted :obj:`OpenAIMessage` object.

        Raises:
            ValueError: If ``role_at_backend`` is neither ``ASSISTANT`` nor
                ``FUNCTION``.
        """
        if role_at_backend == OpenAIBackendRole.ASSISTANT:
            return self.to_openai_assistant_message()
        elif role_at_backend == OpenAIBackendRole.FUNCTION:
            return self.to_openai_tool_message()
        else:
            raise ValueError(f"Unsupported role: {role_at_backend}.")

    def to_sharegpt(
        self,
        function_format: Optional[
            FunctionCallFormatter[ToolCall, ToolResponse]
        ] = None,
    ) -> ShareGPTMessage:
        r"""Convert FunctionCallingMessage to ShareGPT message.

        Args:
            function_format (FunctionCallFormatter[ToolCall, ToolResponse],
                optional): The function formatter to use. Defaults to None,
                in which case a Hermes-style formatter is used.

        Returns:
            ShareGPTMessage: A ``gpt`` message when this represents a
                function call, or a ``tool`` message when it represents a
                function response.
        """

        if function_format is None:
            function_format = HermesFunctionFormatter()
        # The role of the message is an unreliable indicator of whether
        # it is a function call or response, so use result
        if self.result is None:
            # This is a function call
            # TODO: split the incoming types to be more specific
            # and remove the type ignores
            content = function_format.format_tool_call(
                self.content or "",  # type: ignore[arg-type]
                self.func_name,  # type: ignore[arg-type]
                self.args,  # type: ignore[arg-type]
            )
            return ShareGPTMessage(from_="gpt", value=content)  # type: ignore[call-arg]
        else:
            # This is a function response
            # TODO: Allow for more flexible setting of tool role,
            # optionally to be the same as assistant messages
            content = function_format.format_tool_response(
                self.func_name,  # type: ignore[arg-type]
                self.result,  # type: ignore[arg-type]
            )
            return ShareGPTMessage(from_="tool", value=content)  # type: ignore[call-arg]

    def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
        r"""Converts the message to an :obj:`OpenAIAssistantMessage` object.

        Returns:
            OpenAIAssistantMessage: The converted :obj:`OpenAIAssistantMessage`
                object.

        Raises:
            ValueError: If ``func_name`` or ``args`` is missing.
        """
        if (not self.func_name) or (self.args is None):
            raise ValueError(
                "Invalid request for converting into assistant message"
                " due to missing function name or arguments."
            )

        return {
            "role": "assistant",
            "content": self.content or "",
            "tool_calls": [
                {
                    # "null" acts as a sentinel when no tool_call_id exists.
                    "id": self.tool_call_id or "null",
                    "type": "function",
                    "function": {
                        "name": self.func_name,
                        # OpenAI expects arguments as a JSON-encoded string.
                        "arguments": json.dumps(self.args),
                    },
                }
            ],
        }

    def to_openai_tool_message(self) -> OpenAIToolMessageParam:
        r"""Converts the message to an :obj:`OpenAIToolMessageParam` object
        with the role being "tool".

        Returns:
            OpenAIToolMessageParam: The converted
                :obj:`OpenAIToolMessageParam` object with its role being
                "tool".

        Raises:
            ValueError: If ``func_name`` is missing.
        """
        if not self.func_name:
            raise ValueError(
                "Invalid request for converting into function message"
                " due to missing function name."
            )

        # The result is stringified (not JSON-encoded); non-str results
        # therefore appear in their str() form.
        result_content = str(self.result)

        return {
            "role": "tool",
            "content": result_content,
            "tool_call_id": self.tool_call_id or "null",
        }
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from .anthropic_model import AnthropicModel
|
|
||||||
from .azure_openai_model import AzureOpenAIModel
|
|
||||||
from .base_model import BaseModelBackend
|
|
||||||
from .cohere_model import CohereModel
|
|
||||||
from .deepseek_model import DeepSeekModel
|
|
||||||
from .gemini_model import GeminiModel
|
|
||||||
from .groq_model import GroqModel
|
|
||||||
from .litellm_model import LiteLLMModel
|
|
||||||
from .mistral_model import MistralModel
|
|
||||||
from .model_factory import ModelFactory
|
|
||||||
from .model_manager import ModelManager, ModelProcessingError
|
|
||||||
from .nemotron_model import NemotronModel
|
|
||||||
from .nvidia_model import NvidiaModel
|
|
||||||
from .ollama_model import OllamaModel
|
|
||||||
from .openai_audio_models import OpenAIAudioModels
|
|
||||||
from .openai_compatible_model import OpenAICompatibleModel
|
|
||||||
from .openai_model import OpenAIModel
|
|
||||||
from .qwen_model import QwenModel
|
|
||||||
from .reka_model import RekaModel
|
|
||||||
from .samba_model import SambaModel
|
|
||||||
from .stub_model import StubModel
|
|
||||||
from .togetherai_model import TogetherAIModel
|
|
||||||
from .vllm_model import VLLMModel
|
|
||||||
from .yi_model import YiModel
|
|
||||||
from .zhipuai_model import ZhipuAIModel
|
|
||||||
from .fish_audio_model import FishAudioModel
|
|
||||||
|
|
||||||
# Public API of camel.models; must list every backend re-exported above.
__all__ = [
    'BaseModelBackend',
    'OpenAIModel',
    'AzureOpenAIModel',
    'AnthropicModel',
    'MistralModel',
    'GroqModel',
    'StubModel',
    'ZhipuAIModel',
    'CohereModel',
    'ModelFactory',
    'ModelManager',
    'LiteLLMModel',
    'OpenAIAudioModels',
    'NemotronModel',
    'NvidiaModel',
    'OllamaModel',
    'VLLMModel',
    'GeminiModel',
    'OpenAICompatibleModel',
    'RekaModel',
    'SambaModel',
    'TogetherAIModel',
    'YiModel',
    'QwenModel',
    'ModelProcessingError',
    'DeepSeekModel',
    # FishAudioModel is imported above but was missing from the public list.
    'FishAudioModel',
]
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from camel.configs import ANTHROPIC_API_PARAMS, AnthropicConfig
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.models.base_model import BaseModelBackend
|
|
||||||
from camel.types import ChatCompletion, ModelType
|
|
||||||
from camel.utils import (
|
|
||||||
AnthropicTokenCounter,
|
|
||||||
BaseTokenCounter,
|
|
||||||
api_keys_required,
|
|
||||||
dependencies_required,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicModel(BaseModelBackend):
    r"""Anthropic API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created, one of CLAUDE_* series.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into Anthropic.messages.create(). If
            :obj:`None`, :obj:`AnthropicConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the Anthropic service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the Anthropic service.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`AnthropicTokenCounter`
            will be used. (default: :obj:`None`)
    """

    @dependencies_required('anthropic')
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        from anthropic import Anthropic

        if model_config_dict is None:
            model_config_dict = AnthropicConfig().as_dict()
        # Fall back to environment variables when credentials are not
        # passed explicitly.
        api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        url = url or os.environ.get("ANTHROPIC_API_BASE_URL")
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        self.client = Anthropic(api_key=self._api_key, base_url=self._url)

    def _convert_response_from_anthropic_to_openai(self, response):
        r"""Convert an Anthropic ``Message`` response into the OpenAI
        ``ChatCompletion`` format.

        Args:
            response: The raw response object returned by
                ``Anthropic.messages.create()``.

        Returns:
            ChatCompletion: The response repackaged in OpenAI's
                ``chat.completion`` shape (single choice, assistant role).
        """
        # openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
        # `construct` skips pydantic validation, which lets us set id/created
        # to None (Anthropic does not supply OpenAI-compatible values).
        obj = ChatCompletion.construct(
            id=None,
            choices=[
                dict(
                    index=0,
                    message={
                        "role": "assistant",
                        "content": response.content[0].text,
                    },
                    finish_reason=response.stop_reason,
                )
            ],
            created=None,
            model=response.model,
            object="chat.completion",
        )
        return obj

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = AnthropicTokenCounter()
        return self._token_counter

    def count_tokens_from_prompt(self, prompt: str) -> int:
        r"""Count the number of tokens from a prompt.

        Args:
            prompt (str): The prompt string.

        Returns:
            int: The number of tokens in the prompt.
        """
        # NOTE(review): `client.count_tokens` was removed in newer anthropic
        # SDK releases — verify against the pinned SDK version.
        return self.client.count_tokens(prompt)

    @api_keys_required("ANTHROPIC_API_KEY")
    def run(
        self,
        messages: List[OpenAIMessage],
    ):
        r"""Run inference of Anthropic chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            ChatCompletion: Response in the OpenAI API format.
        """
        from anthropic import NOT_GIVEN

        # Anthropic takes the system prompt as a dedicated `system` argument
        # rather than as a message. Split it off WITHOUT mutating the
        # caller's list (the previous implementation popped the first
        # element in place) and guard against an empty message list.
        if messages and messages[0]["role"] == "system":
            sys_msg = str(messages[0]["content"])
            chat_messages = messages[1:]
        else:
            sys_msg = NOT_GIVEN  # type: ignore[assignment]
            chat_messages = messages

        response = self.client.messages.create(
            model=self.model_type,
            system=sys_msg,
            messages=chat_messages,  # type: ignore[arg-type]
            **self.model_config_dict,
        )

        # format response to openai format
        return self._convert_response_from_anthropic_to_openai(response)

    def check_model_config(self):
        r"""Check whether the model configuration is valid for anthropic
        model backends.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to the Anthropic API.
        """
        for param in self.model_config_dict:
            if param not in ANTHROPIC_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into Anthropic model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from openai import AzureOpenAI, Stream
|
|
||||||
|
|
||||||
from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.models.base_model import BaseModelBackend
|
|
||||||
from camel.types import (
|
|
||||||
ChatCompletion,
|
|
||||||
ChatCompletionChunk,
|
|
||||||
ModelType,
|
|
||||||
)
|
|
||||||
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
|
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIModel(BaseModelBackend):
    r"""Azure OpenAI API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created, one of GPT_* series.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into:obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`ChatGPTConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the OpenAI service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the OpenAI service.
            (default: :obj:`None`)
        api_version (Optional[str], optional): The api version for the model.
            (default: :obj:`None`)
        azure_deployment_name (Optional[str], optional): The deployment name
            you chose when you deployed an azure model. (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter`
            will be used. (default: :obj:`None`)

    References:
        https://learn.microsoft.com/en-us/azure/ai-services/openai/
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
        api_version: Optional[str] = None,
        azure_deployment_name: Optional[str] = None,
    ) -> None:
        # Resolve configuration and credentials, falling back to the
        # environment when arguments are omitted.
        config = (
            ChatGPTConfig().as_dict()
            if model_config_dict is None
            else model_config_dict
        )
        resolved_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
        resolved_url = url or os.environ.get("AZURE_OPENAI_BASE_URL")
        super().__init__(
            model_type, config, resolved_key, resolved_url, token_counter
        )

        self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
        self.azure_deployment_name = azure_deployment_name or os.environ.get(
            "AZURE_DEPLOYMENT_NAME"
        )
        # Both values are mandatory for the Azure client; fail fast with a
        # message naming the missing source.
        if self.api_version is None:
            raise ValueError(
                "Must provide either the `api_version` argument "
                "or `AZURE_API_VERSION` environment variable."
            )
        if self.azure_deployment_name is None:
            raise ValueError(
                "Must provide either the `azure_deployment_name` argument "
                "or `AZURE_DEPLOYMENT_NAME` environment variable."
            )

        self._client = AzureOpenAI(
            azure_endpoint=str(self._url),
            azure_deployment=self.azure_deployment_name,
            api_version=self.api_version,
            api_key=self._api_key,
            timeout=60,
            max_retries=3,
        )

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if self._token_counter:
            return self._token_counter
        self._token_counter = OpenAITokenCounter(self.model_type)
        return self._token_counter

    @api_keys_required("AZURE_OPENAI_API_KEY", "AZURE_API_VERSION")
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of Azure OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # Azure routes requests by deployment name, not by model type.
        return self._client.chat.completions.create(
            messages=messages,
            model=self.azure_deployment_name,  # type:ignore[arg-type]
            **self.model_config_dict,
        )

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to Azure OpenAI API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Azure OpenAI API.
        """
        for key in self.model_config_dict:
            if key in OPENAI_API_PARAMS:
                continue
            raise ValueError(
                f"Unexpected argument `{key}` is "
                "input into Azure OpenAI model backend."
            )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode,
        which sends partial results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get("stream", False)
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from openai import Stream
|
|
||||||
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.types import (
|
|
||||||
ChatCompletion,
|
|
||||||
ChatCompletionChunk,
|
|
||||||
ModelType,
|
|
||||||
UnifiedModelType,
|
|
||||||
)
|
|
||||||
from camel.utils import BaseTokenCounter
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModelBackend(ABC):
    r"""Base class for different model backends.
    It may be OpenAI API, a local LLM, a stub for unit tests, etc.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A config
            dictionary. (default: :obj:`{}`)
        api_key (Optional[str], optional): The API key for authenticating
            with the model service. (default: :obj:`None`)
        url (Optional[str], optional): The url to the model service.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token
            counter to use for the model. If not provided,
            :obj:`OpenAITokenCounter` will be used. (default: :obj:`None`)
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        self.model_type: UnifiedModelType = UnifiedModelType(model_type)
        # Normalize a missing config to an empty dict so subclasses can
        # iterate/unpack it unconditionally.
        self.model_config_dict = (
            {} if model_config_dict is None else model_config_dict
        )
        self._api_key = api_key
        self._url = url
        self._token_counter = token_counter
        # Let the concrete backend reject unknown config keys eagerly.
        self.check_model_config()

    @property
    @abstractmethod
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        pass

    @abstractmethod
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs the query to the backend model.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        pass

    @abstractmethod
    def check_model_config(self):
        r"""Check whether the input model configuration contains unexpected
        arguments

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected argument for this model class.
        """
        pass

    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
        r"""Count the number of tokens in the messages using the specific
        tokenizer.

        Args:
            messages (List[Dict]): message list with the chat history
                in OpenAI API format.

        Returns:
            int: Number of tokens in the messages.
        """
        counter = self.token_counter
        return counter.count_tokens_from_messages(messages)

    @property
    def token_limit(self) -> int:
        r"""Returns the maximum token limit for a given model.

        This method retrieves the maximum token limit either from the
        `model_config_dict` or from the model's default token limit.

        Returns:
            int: The maximum token limit for the given model.
        """
        # An explicit `max_tokens` in the config wins; otherwise defer to
        # the model type's built-in limit.
        configured = self.model_config_dict.get("max_tokens")
        return configured or self.model_type.token_limit

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return False
|
|
||||||
@@ -1,282 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
import ast
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from cohere.types import ChatMessageV2, ChatResponse
|
|
||||||
|
|
||||||
from camel.configs import COHERE_API_PARAMS, CohereConfig
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.models import BaseModelBackend
|
|
||||||
from camel.types import ChatCompletion, ModelType
|
|
||||||
from camel.utils import (
|
|
||||||
BaseTokenCounter,
|
|
||||||
OpenAITokenCounter,
|
|
||||||
api_keys_required,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
|
||||||
from agentops import LLMEvent, record
|
|
||||||
else:
|
|
||||||
raise ImportError
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
LLMEvent = None
|
|
||||||
|
|
||||||
|
|
||||||
class CohereModel(BaseModelBackend):
    r"""Cohere API in a unified BaseModelBackend interface."""

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ):
        import cohere

        if model_config_dict is None:
            model_config_dict = CohereConfig().as_dict()

        # Credentials fall back to the environment when not given.
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        url = url or os.environ.get("COHERE_API_BASE_URL")
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        self._client = cohere.ClientV2(api_key=self._api_key)

    def _to_openai_response(self, response: 'ChatResponse') -> ChatCompletion:
        r"""Repackage a Cohere ``ChatResponse`` as an OpenAI-style
        ``ChatCompletion`` object.

        Args:
            response ('ChatResponse'): The raw Cohere v2 chat response.

        Returns:
            ChatCompletion: The response in OpenAI ``chat.completion`` shape.
        """
        # Translate token accounting when Cohere supplies it; otherwise
        # emit an empty usage dict.
        if response.usage and response.usage.tokens:
            n_in = response.usage.tokens.input_tokens or 0
            n_out = response.usage.tokens.output_tokens or 0
            usage = {
                "prompt_tokens": n_in,
                "completion_tokens": n_out,
                "total_tokens": n_in + n_out,
            }
        else:
            usage = {}

        calls = response.message.tool_calls
        choices = []
        if calls:
            # One choice is emitted per tool call, each carrying that single
            # call plus the model's tool plan as content.
            for call in calls:
                converted_calls = [
                    dict(
                        id=call.id,
                        function={
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        }
                        if call.function
                        else {},
                        type=call.type,
                    )
                ]

                choices.append(
                    dict(
                        index=None,
                        message={
                            "role": "assistant",
                            "content": response.message.tool_plan,
                            "tool_calls": converted_calls,
                        },
                        finish_reason=response.finish_reason
                        if response.finish_reason
                        else None,
                    )
                )
        else:
            # Plain text answer: single choice, no tool calls.
            choices.append(
                dict(
                    index=None,
                    message={
                        "role": "assistant",
                        "content": response.message.content[0].text,  # type: ignore[union-attr,index]
                        "tool_calls": None,
                    },
                    finish_reason=response.finish_reason
                    if response.finish_reason
                    else None,
                )
            )

        return ChatCompletion.construct(
            id=response.id,
            choices=choices,
            created=None,
            model=self.model_type,
            object="chat.completion",
            usage=usage,
        )

    def _to_cohere_chatmessage(
        self, messages: List[OpenAIMessage]
    ) -> List["ChatMessageV2"]:
        r"""Convert OpenAI-format chat messages into Cohere v2 message
        objects.

        Args:
            messages (List[OpenAIMessage]): Chat history in OpenAI format.

        Returns:
            List["ChatMessageV2"]: The equivalent Cohere v2 messages.

        Raises:
            ValueError: If a message carries an unsupported role.
        """
        from cohere.types import ToolCallV2Function
        from cohere.types.chat_message_v2 import (
            AssistantChatMessageV2,
            SystemChatMessageV2,
            ToolCallV2,
            ToolChatMessageV2,
            UserChatMessageV2,
        )

        # Generated per assistant tool call and reused for the following
        # tool-result message, since OpenAI-format history lacks Cohere ids.
        pending_call_id = None
        converted = []
        for entry in messages:
            role = entry.get("role")
            content = entry.get("content")
            fn_call = entry.get("function_call")

            if role == "user":
                cohere_msg = UserChatMessageV2(role="user", content=content)  # type: ignore[arg-type]
            elif role in {"tool", "function"}:
                cohere_msg = ToolChatMessageV2(
                    role="tool",
                    tool_call_id=pending_call_id,  # type: ignore[arg-type]
                    content=content,  # type: ignore[assignment,arg-type]
                )
            elif role == "assistant":
                if not fn_call:
                    cohere_msg = AssistantChatMessageV2(  # type: ignore[assignment]
                        role="assistant",
                        content=content,  # type: ignore[arg-type]
                    )
                else:
                    # Arguments arrive as a Python-literal string; normalize
                    # them to strict JSON for Cohere.
                    raw_args = fn_call.get("arguments")  # type: ignore[attr-defined]
                    args_json = json.dumps(ast.literal_eval(raw_args))

                    pending_call_id = str(uuid.uuid4())
                    cohere_msg = AssistantChatMessageV2(  # type: ignore[assignment]
                        role="assistant",
                        tool_calls=[
                            ToolCallV2(
                                id=pending_call_id,
                                type="function",
                                function=ToolCallV2Function(
                                    name=fn_call.get("name"),  # type: ignore[attr-defined]
                                    arguments=args_json,  # type: ignore[attr-defined]
                                ),
                            )
                        ],
                        content=content,  # type: ignore[arg-type]
                    )
            elif role == "system":
                cohere_msg = SystemChatMessageV2(  # type: ignore[assignment]
                    role="system",
                    content=content,  # type: ignore[arg-type]
                )
            else:
                raise ValueError(f"Unsupported message role: {role}")

            converted.append(cohere_msg)
        return converted  # type: ignore[return-value]

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if self._token_counter:
            return self._token_counter
        self._token_counter = OpenAITokenCounter(model=ModelType.GPT_4O_MINI)
        return self._token_counter

    @api_keys_required("COHERE_API_KEY")
    def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
        r"""Runs inference of Cohere chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.
        Returns:
            ChatCompletion.
        """
        from cohere.core.api_error import ApiError

        cohere_messages = self._to_cohere_chatmessage(messages)

        try:
            response = self._client.chat(
                messages=cohere_messages,
                model=self.model_type,
                **self.model_config_dict,
            )
        except ApiError as e:
            logging.error(f"Cohere API Error: {e.status_code}")
            logging.error(f"Error body: {e.body}")
            raise
        except Exception as e:
            logging.error(f"Unexpected error when calling Cohere API: {e!s}")
            raise

        openai_response = self._to_openai_response(response)

        # Add AgentOps LLM Event tracking
        if LLMEvent:
            record(
                LLMEvent(
                    thread_id=openai_response.id,
                    prompt=" ".join(
                        [message.get("content") for message in messages]  # type: ignore[misc]
                    ),
                    prompt_tokens=openai_response.usage.prompt_tokens,  # type: ignore[union-attr]
                    completion=openai_response.choices[0].message.content,
                    completion_tokens=openai_response.usage.completion_tokens,  # type: ignore[union-attr]
                    model=self.model_type,
                )
            )

        return openai_response

    def check_model_config(self):
        r"""Check whether the model configuration contains any unexpected
        arguments to Cohere API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to Cohere API.
        """
        for key in self.model_config_dict:
            if key in COHERE_API_PARAMS:
                continue
            raise ValueError(
                f"Unexpected argument `{key}` is "
                "input into Cohere model backend."
            )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time. Current it's not supported.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return False
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from openai import OpenAI, Stream
|
|
||||||
|
|
||||||
from camel.configs import DEEPSEEK_API_PARAMS, DeepSeekConfig
|
|
||||||
from camel.logger import get_logger
|
|
||||||
from camel.messages import OpenAIMessage
|
|
||||||
from camel.models.base_model import BaseModelBackend
|
|
||||||
from camel.types import (
|
|
||||||
ChatCompletion,
|
|
||||||
ChatCompletionChunk,
|
|
||||||
ModelType,
|
|
||||||
)
|
|
||||||
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_keys_required
|
|
||||||
from retry import retry
|
|
||||||
import json
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekModel(BaseModelBackend):
|
|
||||||
r"""DeepSeek API in a unified BaseModelBackend interface.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_type (Union[ModelType, str]): Model for which a backend is
|
|
||||||
created.
|
|
||||||
model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
|
|
||||||
that will be fed into:obj:`openai.ChatCompletion.create()`. If
|
|
||||||
:obj:`None`, :obj:`DeepSeekConfig().as_dict()` will be used.
|
|
||||||
(default: :obj:`None`)
|
|
||||||
api_key (Optional[str], optional): The API key for authenticating with
|
|
||||||
the DeepSeek service. (default: :obj:`None`)
|
|
||||||
url (Optional[str], optional): The url to the DeepSeek service.
|
|
||||||
(default: :obj:`https://api.deepseek.com`)
|
|
||||||
token_counter (Optional[BaseTokenCounter], optional): Token counter to
|
|
||||||
use for the model. If not provided, :obj:`OpenAITokenCounter`
|
|
||||||
will be used. (default: :obj:`None`)
|
|
||||||
|
|
||||||
References:
|
|
||||||
https://api-docs.deepseek.com/
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_type: Union[ModelType, str],
|
|
||||||
model_config_dict: Optional[Dict[str, Any]] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
url: Optional[str] = None,
|
|
||||||
token_counter: Optional[BaseTokenCounter] = None,
|
|
||||||
) -> None:
|
|
||||||
if model_config_dict is None:
|
|
||||||
model_config_dict = DeepSeekConfig().as_dict()
|
|
||||||
api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
|
||||||
url = url or os.environ.get(
|
|
||||||
"DEEPSEEK_API_BASE_URL",
|
|
||||||
"https://api.deepseek.com",
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
model_type, model_config_dict, api_key, url, token_counter
|
|
||||||
)
|
|
||||||
|
|
||||||
self._client = OpenAI(
|
|
||||||
timeout=180,
|
|
||||||
max_retries=3,
|
|
||||||
api_key=self._api_key,
|
|
||||||
base_url=self._url,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def token_counter(self) -> BaseTokenCounter:
|
|
||||||
r"""Initialize the token counter for the model backend.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseTokenCounter: The token counter following the model's
|
|
||||||
tokenization style.
|
|
||||||
"""
|
|
||||||
if not self._token_counter:
|
|
||||||
self._token_counter = OpenAITokenCounter(
|
|
||||||
model=ModelType.GPT_4O_MINI
|
|
||||||
)
|
|
||||||
return self._token_counter
|
|
||||||
|
|
||||||
@retry((ValueError, TypeError, json.decoder.JSONDecodeError), delay=10, logger=logger)
def run(
    self,
    messages: List[OpenAIMessage],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
    r"""Runs inference of DeepSeek chat completion.

    Retried (10s delay) on ValueError/TypeError/JSONDecodeError raised by
    the request or by response parsing.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.

    Returns:
        Union[ChatCompletion, Stream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `Stream[ChatCompletionChunk]` in the stream mode.
    """
    # deepseek reasoner has limitations
    # reference: https://api-docs.deepseek.com/guides/reasoning_model#api-parameters
    if self.model_type in [
        ModelType.DEEPSEEK_REASONER,
    ]:
        # Local import: `re` is only needed on the reasoner path.
        import re

        logger.warning(
            "You are using a DeepSeek Reasoner model, "
            "which has certain limitations, reference: "
            "`https://api-docs.deepseek.com/guides/reasoning_model#api-parameters`"
        )

        # Check and remove unsupported parameters and reset the fixed
        # parameters
        # NOTE(review): this deletes keys from self.model_config_dict
        # in place, so the removal persists across subsequent run() calls
        # on this instance — confirm this is intended.
        unsupported_keys = [
            "temperature",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "logprobs",
            "top_logprobs",
            "tools",
        ]
        for key in unsupported_keys:
            if key in self.model_config_dict:
                del self.model_config_dict[key]
        # Remove thinking content from messages before sending to API
        # This ensures only the final response is sent, excluding
        # intermediate thought processes
        # NOTE(review): assumes every msg['content'] is a str; a list-form
        # (multimodal) content would make re.sub raise — TODO confirm
        # callers only pass text content on the reasoner path.
        messages = [
            { # type: ignore[misc]
                **msg,
                'content': re.sub(
                    r'<think>.*?</think>',
                    '',
                    msg['content'], # type: ignore[arg-type]
                    flags=re.DOTALL,
                ).strip(),
            }
            for msg in messages
        ]

    response = self._client.chat.completions.create(
        messages=messages,
        model=self.model_type,
        **self.model_config_dict,
    )

    # Handle reasoning content with <think> tags at the beginning
    # Only when the GET_REASONING_CONTENT env var opts in; the reasoning
    # text is prepended to the final answer wrapped in <think> tags.
    if (
        self.model_type
        in [
            ModelType.DEEPSEEK_REASONER,
        ]
        and os.environ.get("GET_REASONING_CONTENT", "false").lower()
        == "true"
    ):
        # `reasoning_content` is a DeepSeek-specific field on the message.
        reasoning_content = response.choices[0].message.reasoning_content
        combined_content = (
            f"<think>\n{reasoning_content}\n</think>\n"
            if reasoning_content
            else ""
        ) + response.choices[0].message.content

        # Rebuild the ChatCompletion so downstream consumers see a single
        # combined content string; only the first choice is carried over.
        response = ChatCompletion.construct(
            id=response.id,
            choices=[
                dict(
                    index=response.choices[0].index,
                    message={
                        "role": response.choices[0].message.role,
                        "content": combined_content,
                        "tool_calls": None,
                    },
                    finish_reason=response.choices[0].finish_reason
                    if response.choices[0].finish_reason
                    else None,
                )
            ],
            created=response.created,
            model=response.model,
            object="chat.completion",
            usage=response.usage,
        )

    return response
|
|
||||||
|
|
||||||
def check_model_config(self):
    r"""Check whether the model configuration contains any
    unexpected arguments to DeepSeek API.

    Raises:
        ValueError: If the model configuration dictionary contains any
            unexpected arguments to DeepSeek API.
    """
    # Reject on the first configured key that DeepSeek does not accept.
    unexpected = [
        name
        for name in self.model_config_dict
        if name not in DEEPSEEK_API_PARAMS
    ]
    if unexpected:
        raise ValueError(
            f"Unexpected argument `{unexpected[0]}` is "
            "input into DeepSeek model backend."
        )
|
|
||||||
|
|
||||||
@property
def stream(self) -> bool:
    r"""Returns whether the model is in stream mode, which sends partial
    results each time.

    Returns:
        bool: Whether the model is in stream mode.
    """
    # Streaming is off unless explicitly enabled in the model config.
    config = self.model_config_dict
    return config.get("stream", False)
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class FishAudioModel:
    r"""Provides access to FishAudio's Text-to-Speech (TTS) and
    Speech-to-Text (STT) models.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        r"""Initialize an instance of FishAudioModel.

        Args:
            api_key (Optional[str]): API key for FishAudio service. If not
                provided, the environment variable `FISHAUDIO_API_KEY` will be
                used.
            url (Optional[str]): Base URL for FishAudio API. If not provided,
                the environment variable `FISHAUDIO_API_BASE_URL` will be
                used, defaulting to `https://api.fish.audio`.
        """
        # Imported lazily so the SDK is only required when this model is used.
        from fish_audio_sdk import Session

        self._api_key = api_key or os.environ.get("FISHAUDIO_API_KEY")
        self._url = url or os.environ.get(
            "FISHAUDIO_API_BASE_URL", "https://api.fish.audio"
        )
        self.session = Session(apikey=self._api_key, base_url=self._url)

    def _stream_tts_to_file(self, request: Any, storage_path: str) -> None:
        r"""Stream a TTS response to ``storage_path`` chunk by chunk."""
        with open(storage_path, "wb") as f:
            for chunk in self.session.tts(request):
                f.write(chunk)

    def text_to_speech(
        self,
        input: str,
        storage_path: str,
        reference_id: Optional[str] = None,
        reference_audio: Optional[str] = None,
        reference_audio_text: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Convert text to speech and save the output to a file.

        Args:
            input (str): The text to convert to speech.
            storage_path (str): The file path where the resulting speech will
                be saved.
            reference_id (Optional[str]): An optional reference ID to
                associate with the request. (default: :obj:`None`)
            reference_audio (Optional[str]): Path to an audio file for
                reference speech. (default: :obj:`None`)
            reference_audio_text (Optional[str]): Text for the reference
                audio. (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the TTS request.

        Raises:
            FileNotFoundError: If the reference audio file cannot be found.
            ValueError: If `reference_audio` is given without
                `reference_audio_text`.
        """
        from fish_audio_sdk import ReferenceAudio, TTSRequest

        # exist_ok avoids the check-then-create race of the exists()/makedirs
        # pair, and is a no-op when the directory already exists.
        directory = os.path.dirname(storage_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        if not reference_audio:
            # Plain TTS path, optionally pinned to a stored reference voice.
            self._stream_tts_to_file(
                TTSRequest(reference_id=reference_id, text=input, **kwargs),
                storage_path,
            )
            return

        # Voice-cloning path: validate the reference inputs up front.
        if not os.path.exists(reference_audio):
            raise FileNotFoundError(
                f"Reference audio file not found: {reference_audio}"
            )
        if not reference_audio_text:
            raise ValueError("reference_audio_text should be provided")

        with open(reference_audio, "rb") as audio_file:
            self._stream_tts_to_file(
                TTSRequest(
                    text=input,
                    references=[
                        ReferenceAudio(
                            audio=audio_file.read(),
                            text=reference_audio_text,
                        )
                    ],
                    **kwargs,
                ),
                storage_path,
            )

    def speech_to_text(
        self,
        audio_file_path: str,
        language: Optional[str] = None,
        ignore_timestamps: Optional[bool] = None,
        **kwargs: Any,
    ) -> str:
        r"""Convert speech to text from an audio file.

        Args:
            audio_file_path (str): The path to the audio file to transcribe.
            language (Optional[str]): The language of the audio. (default:
                :obj:`None`)
            ignore_timestamps (Optional[bool]): Whether to ignore timestamps.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the STT request.

        Returns:
            str: The transcribed text from the audio.

        Raises:
            FileNotFoundError: If the audio file cannot be found.
        """
        from fish_audio_sdk import ASRRequest

        if not os.path.exists(audio_file_path):
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        with open(audio_file_path, "rb") as audio_file:
            audio_data = audio_file.read()

        response = self.session.asr(
            ASRRequest(
                audio=audio_data,
                language=language,
                ignore_timestamps=ignore_timestamps,
                **kwargs,
            )
        )
        return response.text
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user