mirror of
https://github.com/camel-ai/owl.git
synced 2026-03-22 05:57:17 +08:00
initial update workforce code for gaia
This commit is contained in:
41
.env_example
Normal file
41
.env_example
Normal file
@@ -0,0 +1,41 @@
|
||||
# To use these environment variables:
|
||||
# 1. Populate the .env file with your API keys.
|
||||
# 2. Include the following code snippet in your Python script:
|
||||
# from dotenv import load_dotenv
|
||||
# import os
|
||||
#
|
||||
# load_dotenv() # Load environment variables from .env file
|
||||
|
||||
#===========================================
|
||||
# Models API
|
||||
#===========================================
|
||||
|
||||
# OpenAI API (https://platform.openai.com/signup)
|
||||
OPENAI_API_KEY="Fill your API key here"
|
||||
|
||||
# Anthropic API (https://www.anthropic.com/)
|
||||
ANTHROPIC_API_KEY="Fill your API key here"
|
||||
|
||||
# Hugging Face API (https://huggingface.co/join)
|
||||
HF_TOKEN="Fill your API key here"
|
||||
|
||||
# Azure OpenAI API (https://azure.microsoft.com/products/cognitive-services/openai-service/)
|
||||
AZURE_OPENAI_API_KEY="Fill your API key here"
|
||||
AZURE_API_VERSION="Fill your API Version here"
|
||||
AZURE_DEPLOYMENT_NAME="Fill your Deployment Name here"
|
||||
AZURE_OPENAI_BASE_URL="Fill your Base URL here"
|
||||
|
||||
#===========================================
|
||||
# Tools & Services API
|
||||
#===========================================
|
||||
|
||||
# Google Search API (https://developers.google.com/custom-search/v1/overview)
|
||||
GOOGLE_API_KEY="Fill your API key here"
|
||||
SEARCH_ENGINE_ID="Fill your Search Engine ID here"
|
||||
|
||||
|
||||
# Firecrawl API (https://www.firecrawl.dev/)
|
||||
FIRECRAWL_API_KEY="Fill your API key here"
|
||||
|
||||
# Chunkr API (https://chunkr.ai/)
|
||||
CHUNKR_API_KEY="Fill your API key here"
|
||||
63
.gitignore
vendored
Normal file
63
.gitignore
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
**/__pycache__/
|
||||
*/__pycache__/*
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
.dist
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual Environment
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
.DS_Store
|
||||
|
||||
# Project specific
|
||||
data/gaia
|
||||
tmp
|
||||
.env
|
||||
utils/__pycache__/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
log/
|
||||
|
||||
# Coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
coverage.xml
|
||||
*.cover
|
||||
|
||||
camel/types/__pycache__/
|
||||
camel/__pycache__/
|
||||
camel/utils/__pycache_/
|
||||
|
||||
data/*
|
||||
|
||||
25
camel/__init__.py
Normal file
25
camel/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from camel.logger import disable_logging, enable_logging, set_log_level
|
||||
|
||||
__version__ = '0.2.47'
|
||||
|
||||
__all__ = [
|
||||
'__version__',
|
||||
'camel',
|
||||
'disable_logging',
|
||||
'enable_logging',
|
||||
'set_log_level',
|
||||
]
|
||||
46
camel/agents/__init__.py
Normal file
46
camel/agents/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base import BaseAgent
|
||||
from .chat_agent import ChatAgent
|
||||
from .critic_agent import CriticAgent
|
||||
from .embodied_agent import EmbodiedAgent
|
||||
from .knowledge_graph_agent import KnowledgeGraphAgent
|
||||
from .repo_agent import RepoAgent
|
||||
from .role_assignment_agent import RoleAssignmentAgent
|
||||
from .search_agent import SearchAgent
|
||||
from .task_agent import (
|
||||
TaskCreationAgent,
|
||||
TaskPlannerAgent,
|
||||
TaskPrioritizationAgent,
|
||||
TaskSpecifyAgent,
|
||||
)
|
||||
from .tool_agents.base import BaseToolAgent
|
||||
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
||||
|
||||
__all__ = [
|
||||
'BaseAgent',
|
||||
'ChatAgent',
|
||||
'TaskSpecifyAgent',
|
||||
'TaskPlannerAgent',
|
||||
'TaskCreationAgent',
|
||||
'TaskPrioritizationAgent',
|
||||
'CriticAgent',
|
||||
'BaseToolAgent',
|
||||
'HuggingFaceToolAgent',
|
||||
'EmbodiedAgent',
|
||||
'RoleAssignmentAgent',
|
||||
'SearchAgent',
|
||||
'KnowledgeGraphAgent',
|
||||
'RepoAgent',
|
||||
]
|
||||
41
camel/agents/_types.py
Normal file
41
camel/agents/_types.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncStream, Stream
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.types import ChatCompletion
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
r"""The request for tool calling."""
|
||||
|
||||
tool_name: str
|
||||
args: Dict[str, Any]
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
r"""The response from the model."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
response: Union[ChatCompletion, Stream, AsyncStream]
|
||||
tool_call_requests: Optional[List[ToolCallRequest]]
|
||||
output_messages: List[BaseMessage]
|
||||
finish_reasons: List[str]
|
||||
usage_dict: Dict[str, Any]
|
||||
response_id: str
|
||||
188
camel/agents/_utils.py
Normal file
188
camel/agents/_utils.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from camel.agents._types import ToolCallRequest
|
||||
from camel.toolkits import FunctionTool
|
||||
from camel.types import Choice
|
||||
from camel.types.agents import ToolCallingRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_tool_prompt(tool_schema_list: List[Dict[str, Any]]) -> str:
|
||||
r"""Generates a tool prompt based on the provided tool schema list.
|
||||
|
||||
Returns:
|
||||
str: A string representing the tool prompt.
|
||||
"""
|
||||
tool_prompts = []
|
||||
|
||||
for tool in tool_schema_list:
|
||||
tool_info = tool["function"]
|
||||
tool_name = tool_info["name"]
|
||||
tool_description = tool_info["description"]
|
||||
tool_json = json.dumps(tool_info, indent=4, ensure_ascii=False)
|
||||
|
||||
prompt = (
|
||||
f"Use the function '{tool_name}' to '{tool_description}':\n"
|
||||
f"{tool_json}\n"
|
||||
)
|
||||
tool_prompts.append(prompt)
|
||||
|
||||
tool_prompt_str = "\n".join(tool_prompts)
|
||||
|
||||
final_prompt = textwrap.dedent(
|
||||
f"""\
|
||||
You have access to the following functions:
|
||||
|
||||
{tool_prompt_str}
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
||||
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
|
||||
""" # noqa: E501
|
||||
)
|
||||
return final_prompt
|
||||
|
||||
|
||||
def extract_tool_call(
|
||||
content: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
r"""Extract the tool call from the model response, if present.
|
||||
|
||||
Args:
|
||||
response (Any): The model's response object.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The parsed tool call if present,
|
||||
otherwise None.
|
||||
"""
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
match = re.search(function_regex, content)
|
||||
|
||||
if not match:
|
||||
return None
|
||||
|
||||
function_name, args_string = match.groups()
|
||||
try:
|
||||
args = json.loads(args_string)
|
||||
return {"function": function_name, "arguments": args}
|
||||
except json.JSONDecodeError as error:
|
||||
logger.error(f"Error parsing function arguments: {error}")
|
||||
return None
|
||||
|
||||
|
||||
def safe_model_dump(obj) -> Dict[str, Any]:
|
||||
r"""Safely dump a Pydantic model to a dictionary.
|
||||
|
||||
This method attempts to use the `model_dump` method if available,
|
||||
otherwise it falls back to the `dict` method.
|
||||
"""
|
||||
# Check if the `model_dump` method exists (Pydantic v2)
|
||||
if hasattr(obj, "model_dump"):
|
||||
return obj.model_dump()
|
||||
# Fallback to `dict()` method (Pydantic v1)
|
||||
elif hasattr(obj, "dict"):
|
||||
return obj.dict()
|
||||
else:
|
||||
raise TypeError("The object is not a Pydantic model")
|
||||
|
||||
|
||||
def convert_to_function_tool(
|
||||
tool: Union[FunctionTool, Callable],
|
||||
) -> FunctionTool:
|
||||
r"""Convert a tool to a FunctionTool from Callable."""
|
||||
return tool if isinstance(tool, FunctionTool) else FunctionTool(tool)
|
||||
|
||||
|
||||
def convert_to_schema(
|
||||
tool: Union[FunctionTool, Callable, Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
r"""Convert a tool to a schema from Callable or FunctionTool."""
|
||||
if isinstance(tool, FunctionTool):
|
||||
return tool.get_openai_tool_schema()
|
||||
elif callable(tool):
|
||||
return FunctionTool(tool).get_openai_tool_schema()
|
||||
else:
|
||||
return tool
|
||||
|
||||
|
||||
def get_info_dict(
|
||||
session_id: Optional[str],
|
||||
usage: Optional[Dict[str, int]],
|
||||
termination_reasons: List[str],
|
||||
num_tokens: int,
|
||||
tool_calls: List[ToolCallingRecord],
|
||||
external_tool_call_requests: Optional[List[ToolCallRequest]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Returns a dictionary containing information about the chat session.
|
||||
|
||||
Args:
|
||||
session_id (str, optional): The ID of the chat session.
|
||||
usage (Dict[str, int], optional): Information about the usage of
|
||||
the LLM.
|
||||
termination_reasons (List[str]): The reasons for the termination
|
||||
of the chat session.
|
||||
num_tokens (int): The number of tokens used in the chat session.
|
||||
tool_calls (List[ToolCallingRecord]): The list of function
|
||||
calling records, containing the information of called tools.
|
||||
external_tool_call_requests (Optional[List[ToolCallRequest]]): The
|
||||
requests for external tool calls.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The chat session information.
|
||||
"""
|
||||
return {
|
||||
"id": session_id,
|
||||
"usage": usage,
|
||||
"termination_reasons": termination_reasons,
|
||||
"num_tokens": num_tokens,
|
||||
"tool_calls": tool_calls,
|
||||
"external_tool_call_requests": external_tool_call_requests,
|
||||
}
|
||||
|
||||
|
||||
def handle_logprobs(choice: Choice) -> Optional[List[Dict[str, Any]]]:
|
||||
if choice.logprobs is None:
|
||||
return None
|
||||
|
||||
tokens_logprobs = choice.logprobs.content
|
||||
|
||||
if tokens_logprobs is None:
|
||||
return None
|
||||
|
||||
return [
|
||||
{
|
||||
"token": token_logprob.token,
|
||||
"logprob": token_logprob.logprob,
|
||||
"top_logprobs": [
|
||||
(top_logprob.token, top_logprob.logprob)
|
||||
for top_logprob in token_logprob.top_logprobs
|
||||
],
|
||||
}
|
||||
for token_logprob in tokens_logprobs
|
||||
]
|
||||
29
camel/agents/base.py
Normal file
29
camel/agents/base.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
r"""An abstract base class for all CAMEL agents."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""Resets the agent to its initial state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(self, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""Performs a single step of the agent."""
|
||||
pass
|
||||
1407
camel/agents/chat_agent.py
Normal file
1407
camel/agents/chat_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
202
camel/agents/critic_agent.py
Normal file
202
camel/agents/critic_agent.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.memories import AgentMemory
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.utils import get_first_int, print_text_animated
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="CriticAgent")
|
||||
class CriticAgent(ChatAgent):
|
||||
r"""A class for the critic agent that assists in selecting an option.
|
||||
|
||||
Args:
|
||||
system_message (BaseMessage): The system message for the critic
|
||||
agent.
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`6`)
|
||||
retry_attempts (int, optional): The number of retry attempts if the
|
||||
critic fails to return a valid option. (default: :obj:`2`)
|
||||
verbose (bool, optional): Whether to print the critic's messages.
|
||||
logger_color (Any): The color of the menu options displayed to the
|
||||
user. (default: :obj:`Fore.MAGENTA`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_message: BaseMessage,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
memory: Optional[AgentMemory] = None,
|
||||
message_window_size: int = 6,
|
||||
retry_attempts: int = 2,
|
||||
verbose: bool = False,
|
||||
logger_color: Any = Fore.MAGENTA,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
memory=memory,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
self.options_dict: Dict[str, str] = dict()
|
||||
self.retry_attempts = retry_attempts
|
||||
self.verbose = verbose
|
||||
self.logger_color = logger_color
|
||||
|
||||
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
|
||||
r"""Flattens the options to the critic.
|
||||
|
||||
Args:
|
||||
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
|
||||
|
||||
Returns:
|
||||
str: A string containing the flattened options to the critic.
|
||||
"""
|
||||
options = [message.content for message in messages]
|
||||
flatten_options = (
|
||||
f"> Proposals from "
|
||||
f"{messages[0].role_name} ({messages[0].role_type}). "
|
||||
"Please choose an option:\n"
|
||||
)
|
||||
for index, option in enumerate(options):
|
||||
flatten_options += f"Option {index + 1}:\n{option}\n\n"
|
||||
self.options_dict[str(index + 1)] = option
|
||||
format = (
|
||||
f"Please first enter your choice ([1-{len(self.options_dict)}]) "
|
||||
"and then your explanation and comparison: "
|
||||
)
|
||||
return flatten_options + format
|
||||
|
||||
def get_option(self, input_message: BaseMessage) -> str:
|
||||
r"""Gets the option selected by the critic.
|
||||
|
||||
Args:
|
||||
input_message (BaseMessage): A `BaseMessage` object representing
|
||||
the input message.
|
||||
|
||||
Returns:
|
||||
str: The option selected by the critic.
|
||||
"""
|
||||
# TODO: Add support for editing options by the critic.
|
||||
msg_content = input_message.content
|
||||
i = 0
|
||||
while i < self.retry_attempts:
|
||||
critic_response = self.step(input_message)
|
||||
|
||||
if critic_response.msgs is None or len(critic_response.msgs) == 0:
|
||||
raise RuntimeError("Got None critic messages.")
|
||||
if critic_response.terminated:
|
||||
raise RuntimeError("Critic step failed.")
|
||||
|
||||
critic_msg = critic_response.msg
|
||||
if self.verbose:
|
||||
print_text_animated(
|
||||
self.logger_color + "\n> Critic response: "
|
||||
f"\x1b[3m{critic_msg.content}\x1b[0m\n"
|
||||
)
|
||||
choice = self.parse_critic(critic_msg)
|
||||
|
||||
if choice in self.options_dict:
|
||||
return self.options_dict[choice]
|
||||
else:
|
||||
input_message = BaseMessage(
|
||||
role_name=input_message.role_name,
|
||||
role_type=input_message.role_type,
|
||||
meta_dict=input_message.meta_dict,
|
||||
content="> Invalid choice. Please choose again.\n"
|
||||
+ msg_content,
|
||||
)
|
||||
i += 1
|
||||
warnings.warn(
|
||||
"Critic failed to get a valid option. "
|
||||
f"After {self.retry_attempts} attempts. "
|
||||
"Returning a random option."
|
||||
)
|
||||
return random.choice(list(self.options_dict.values()))
|
||||
|
||||
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
|
||||
r"""Parses the critic's message and extracts the choice.
|
||||
|
||||
Args:
|
||||
critic_msg (BaseMessage): A `BaseMessage` object representing the
|
||||
critic's response.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The critic's choice as a string, or None if the
|
||||
message could not be parsed.
|
||||
"""
|
||||
choice = str(get_first_int(critic_msg.content))
|
||||
return choice
|
||||
|
||||
def reduce_step(
|
||||
self,
|
||||
input_messages: Sequence[BaseMessage],
|
||||
) -> ChatAgentResponse:
|
||||
r"""Performs one step of the conversation by flattening options to the
|
||||
critic, getting the option, and parsing the choice.
|
||||
|
||||
Args:
|
||||
input_messages (Sequence[BaseMessage]): A list of BaseMessage
|
||||
objects.
|
||||
|
||||
Returns:
|
||||
ChatAgentResponse: A `ChatAgentResponse` object includes the
|
||||
critic's choice.
|
||||
"""
|
||||
meta_chat_message = BaseMessage(
|
||||
role_name=input_messages[0].role_name,
|
||||
role_type=input_messages[0].role_type,
|
||||
meta_dict=input_messages[0].meta_dict,
|
||||
content="",
|
||||
)
|
||||
|
||||
flatten_options = self.flatten_options(input_messages)
|
||||
if self.verbose:
|
||||
print_text_animated(
|
||||
self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
|
||||
)
|
||||
input_msg = meta_chat_message.create_new_instance(flatten_options)
|
||||
|
||||
option = self.get_option(input_msg)
|
||||
output_msg = meta_chat_message.create_new_instance(option)
|
||||
|
||||
# TODO: The return `info` can be improved.
|
||||
return ChatAgentResponse(
|
||||
msgs=[output_msg],
|
||||
terminated=False,
|
||||
info={},
|
||||
)
|
||||
303
camel/agents/deductive_reasoner_agent.py
Normal file
303
camel/agents/deductive_reasoner_agent.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="DeductiveReasonerAgent")
|
||||
class DeductiveReasonerAgent(ChatAgent):
|
||||
r"""An agent responsible for deductive reasoning. Model of deductive
|
||||
reasoning:
|
||||
- L: A ⊕ C -> q * B
|
||||
- A represents the known starting state.
|
||||
- B represents the known target state.
|
||||
- C represents the conditions required to transition from A to B.
|
||||
- Q represents the quality or effectiveness of the transition from
|
||||
A to B.
|
||||
- L represents the path or process from A to B.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Insight Agent",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You assign roles based on tasks.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def deduce_conditions_and_quality(
|
||||
self,
|
||||
starting_state: str,
|
||||
target_state: str,
|
||||
role_descriptions_dict: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Union[List[str], Dict[str, str]]]:
|
||||
r"""Derives the conditions and quality from the starting state and the
|
||||
target state based on the model of the deductive reasoning and the
|
||||
knowledge base. It can optionally consider the roles involved in the
|
||||
scenario, which allows tailoring the output more closely to the AI
|
||||
agent's environment.
|
||||
|
||||
Args:
|
||||
starting_state (str): The initial or starting state from which
|
||||
conditions are deduced.
|
||||
target_state (str): The target state of the task.
|
||||
role_descriptions_dict (Optional[Dict[str, str]], optional): The
|
||||
descriptions of the roles. (default: :obj:`None`)
|
||||
role_descriptions_dict (Optional[Dict[str, str]], optional): A
|
||||
dictionary describing the roles involved in the scenario. This
|
||||
is optional and can be used to provide a context for the
|
||||
CAMEL's role-playing, enabling the generation of more relevant
|
||||
and tailored conditions and quality assessments. This could be
|
||||
generated using a `RoleAssignmentAgent()` or defined manually
|
||||
by the user.
|
||||
|
||||
Returns:
|
||||
Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
|
||||
extracted data from the message. The dictionary contains three
|
||||
keys:
|
||||
- 'conditions': A list where each key is a condition ID and
|
||||
each value is the corresponding condition text.
|
||||
- 'labels': A list of label strings extracted from the message.
|
||||
- 'quality': A string of quality assessment strings extracted
|
||||
from the message.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
deduce_prompt = """You are a deductive reasoner. You are tasked to
|
||||
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
|
||||
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
|
||||
CONTENT to help you complete the TASK.
|
||||
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
|
||||
fill in the BLANKs, and DO NOT alter or modify any other part of the template
|
||||
|
||||
===== MODELING OF DEDUCTIVE REASONING =====
|
||||
You are tasked with understanding a mathematical model based on the components
|
||||
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
|
||||
- $A$ represents the known starting state.
|
||||
- $B$ represents the known target state.
|
||||
- $C$ represents the conditions required to transition from $A$ to $B$.
|
||||
- $Q$ represents the quality or effectiveness of the transition from $A$ to
|
||||
$B$.
|
||||
- $L$ represents the path or process from $A$ to $B$.
|
||||
|
||||
===== THOUGHT OF DEDUCTIVE REASONING =====
|
||||
1. Define the Parameters of A and B:
|
||||
- Characterization: Before delving into transitions, thoroughly understand
|
||||
the nature and boundaries of both $A$ and $B$. This includes the type,
|
||||
properties, constraints, and possible interactions between the two.
|
||||
- Contrast and Compare: Highlight the similarities and differences between
|
||||
$A$ and $B$. This comparative analysis will give an insight into what
|
||||
needs changing and what remains constant.
|
||||
2. Historical & Empirical Analysis:
|
||||
- Previous Transitions according to the Knowledge Base of GPT: (if
|
||||
applicable) Extract conditions and patterns from the historical instances
|
||||
where a similar transition from a state comparable to $A$ moved towards
|
||||
$B$.
|
||||
- Scientific Principles: (if applicable) Consider the underlying
|
||||
scientific principles governing or related to the states and their
|
||||
transition. For example, if $A$ and $B$ are physical states, laws of
|
||||
physics might apply.
|
||||
3. Logical Deduction of Conditions ($C$):
|
||||
- Direct Path Analysis: What are the immediate and direct conditions
|
||||
required to move from $A$ to $B$?
|
||||
- Intermediate States: Are there states between $A$ and $B$ that must be
|
||||
traversed or can be used to make the transition smoother or more
|
||||
efficient? If yes, what is the content?
|
||||
- Constraints & Limitations: Identify potential barriers or restrictions
|
||||
in moving from $A$ to $B$. These can be external (e.g., environmental
|
||||
factors) or internal (properties of $A$ or $B$).
|
||||
- Resource and Information Analysis: What resources and information are
|
||||
required for the transition? This could be time, entity, factor, code
|
||||
language, software platform, unknowns, etc.
|
||||
- External Influences: Consider socio-economic, political, or
|
||||
environmental factors (if applicable) that could influence the transition
|
||||
conditions.
|
||||
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
|
||||
no matter how unconventional they might seem. Utilize analogies,
|
||||
metaphors, or brainstorming techniques to envision possible conditions or
|
||||
paths from $A$ to $B$.
|
||||
- The conditions $C$ should be multiple but in one sentence. And each
|
||||
condition should be concerned with one aspect/entity.
|
||||
4. Entity/Label Recognition of Conditions ($C$):
|
||||
- Identify and categorize entities of Conditions ($C$) such as the names,
|
||||
locations, dates, specific technical terms or contextual parameters that
|
||||
might be associated with events, innovations post-2022.
|
||||
- The output of the entities/labels will be used as tags or labels for
|
||||
semantic similarity searches. The entities/labels may be the words, or
|
||||
phrases, each of them should contain valuable, high information entropy
|
||||
information, and should be independent.
|
||||
- Ensure that the identified entities are formatted in a manner suitable
|
||||
for database indexing and retrieval. Organize the entities into
|
||||
categories, and combine the category with its instance into a continuous
|
||||
phrase, without using colons or other separators.
|
||||
- Format these entities for database indexing: output the category rather
|
||||
than its instance/content into a continuous phrase. For example, instead
|
||||
of "Jan. 02", identify it as "Event time".
|
||||
5. Quality Assessment ($Q$):
|
||||
- Efficiency: How efficient is the transition from $A$ to $B$, which
|
||||
measures the resources used versus the desired outcome?
|
||||
- Effectiveness: Did the transition achieve the desired outcome or was the
|
||||
target state achieved as intended?
|
||||
- Safety & Risks: Assess any risks associated with the transition and the
|
||||
measures to mitigate them.
|
||||
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
|
||||
and adjust the quality of transition, making it more adaptive.
|
||||
6. Iterative Evaluation:
|
||||
- Test & Refine: Based on the initially deduced conditions and assessed
|
||||
quality, iterate the process to refine and optimize the transition. This
|
||||
might involve tweaking conditions, employing different paths, or changing
|
||||
resources.
|
||||
- Feedback Integration: Use feedback to make improvements and increase the
|
||||
quality of the transition.
|
||||
7. Real-world scenarios often present challenges that may not be captured by
|
||||
models and frameworks. While using the model, maintain an adaptive mindset:
|
||||
- Scenario Exploration: Continuously imagine various possible scenarios,
|
||||
both positive and negative, to prepare for unexpected events.
|
||||
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
|
||||
process ($L$) if unforeseen challenges arise.
|
||||
- Feedback Integration: Rapidly integrate feedback from actual
|
||||
implementations to adjust the model's application, ensuring relevancy and
|
||||
effectiveness.
|
||||
|
||||
===== TASK =====
|
||||
Given the starting state $A$ and the target state $B$, assuming that a path
|
||||
$L$ always exists between $A$ and $B$, how can one deduce or identify the
|
||||
necessary conditions $C$ and the quality $Q$ of the transition?
|
||||
|
||||
===== STARTING STATE $A$ =====
|
||||
{starting_state}
|
||||
|
||||
===== TARGET STATE $B$ =====
|
||||
{target_state}
|
||||
|
||||
{role_with_description_prompt}
|
||||
===== ANSWER TEMPLATE =====
|
||||
- Characterization and comparison of $A$ and $B$:\n<BLANK>
|
||||
- Historical & Empirical Analysis:\n<BLANK>/None
|
||||
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
|
||||
condition <NUM>:
|
||||
<BLANK>.
|
||||
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
|
||||
square brackets)
|
||||
- Quality Assessment ($Q$) (do not use symbols):
|
||||
<BLANK>.
|
||||
- Iterative Evaluation:\n<BLANK>/None"""
|
||||
|
||||
if role_descriptions_dict is not None:
|
||||
role_names = role_descriptions_dict.keys()
|
||||
role_with_description_prompt = (
|
||||
"===== ROLES WITH DESCRIPTIONS =====\n"
|
||||
+ "\n".join(
|
||||
f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
|
||||
for role_name in role_names
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
else:
|
||||
role_with_description_prompt = ""
|
||||
deduce_prompt = TextPrompt(deduce_prompt)
|
||||
|
||||
deduce = deduce_prompt.format(
|
||||
starting_state=starting_state,
|
||||
target_state=target_state,
|
||||
role_with_description_prompt=role_with_description_prompt,
|
||||
)
|
||||
|
||||
conditions_and_quality_generation_msg = BaseMessage.make_user_message(
|
||||
role_name="Deductive Reasoner", content=deduce
|
||||
)
|
||||
|
||||
response = self.step(
|
||||
input_message=conditions_and_quality_generation_msg
|
||||
)
|
||||
|
||||
if response.terminated:
|
||||
raise RuntimeError(
|
||||
"Deduction failed. Error:\n" + f"{response.info}"
|
||||
)
|
||||
msg: BaseMessage = response.msg
|
||||
logger.info(f"Message content:\n{msg.content}")
|
||||
|
||||
# Extract the conditions from the message
|
||||
conditions_dict = {
|
||||
f"condition {i}": cdt.replace("<", "")
|
||||
.replace(">", "")
|
||||
.strip()
|
||||
.strip('\n')
|
||||
for i, cdt in re.findall(
|
||||
r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
}
|
||||
|
||||
# Extract the labels from the message
|
||||
labels = [
|
||||
label.strip().strip('\n').strip("\"'")
|
||||
for label in re.findall(
|
||||
r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)[0].split(",")
|
||||
]
|
||||
|
||||
# Extract the quality from the message
|
||||
quality = next(
|
||||
q.strip().strip('\n')
|
||||
for q in re.findall(
|
||||
r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
|
||||
r"\n(.+?)- Iterative",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert them into JSON format
|
||||
conditions_and_quality_json: Dict[
|
||||
str, Union[List[str], Dict[str, str]]
|
||||
] = {}
|
||||
conditions_and_quality_json["conditions"] = conditions_dict
|
||||
conditions_and_quality_json["labels"] = labels
|
||||
conditions_and_quality_json["evaluate_quality"] = quality
|
||||
|
||||
return conditions_and_quality_json
|
||||
201
camel/agents/embodied_agent.py
Normal file
201
camel/agents/embodied_agent.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.agents.tool_agents.base import BaseToolAgent
|
||||
from camel.interpreters import (
|
||||
BaseInterpreter,
|
||||
InternalPythonInterpreter,
|
||||
SubprocessInterpreter,
|
||||
)
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.utils import print_text_animated
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="EmbodiedAgent")
|
||||
class EmbodiedAgent(ChatAgent):
|
||||
r"""Class for managing conversations of CAMEL Embodied Agents.
|
||||
|
||||
Args:
|
||||
system_message (BaseMessage): The system message for the chat agent.
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`None`)
|
||||
tool_agents (List[BaseToolAgent], optional): The tools agents to use in
|
||||
the embodied agent. (default: :obj:`None`)
|
||||
code_interpreter (BaseInterpreter, optional): The code interpreter to
|
||||
execute codes. If `code_interpreter` and `tool_agent` are both
|
||||
`None`, default to `SubProcessInterpreter`. If `code_interpreter`
|
||||
is `None` and `tool_agents` is not `None`, default to
|
||||
`InternalPythonInterpreter`. (default: :obj:`None`)
|
||||
verbose (bool, optional): Whether to print the critic's messages.
|
||||
logger_color (Any): The color of the logger displayed to the user.
|
||||
(default: :obj:`Fore.MAGENTA`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_message: BaseMessage,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
message_window_size: Optional[int] = None,
|
||||
tool_agents: Optional[List[BaseToolAgent]] = None,
|
||||
code_interpreter: Optional[BaseInterpreter] = None,
|
||||
verbose: bool = False,
|
||||
logger_color: Any = Fore.MAGENTA,
|
||||
) -> None:
|
||||
self.tool_agents = tool_agents
|
||||
self.code_interpreter: BaseInterpreter
|
||||
if code_interpreter is not None:
|
||||
self.code_interpreter = code_interpreter
|
||||
elif self.tool_agents:
|
||||
self.code_interpreter = InternalPythonInterpreter()
|
||||
else:
|
||||
self.code_interpreter = SubprocessInterpreter()
|
||||
|
||||
if self.tool_agents:
|
||||
system_message = self._set_tool_agents(system_message)
|
||||
self.verbose = verbose
|
||||
self.logger_color = logger_color
|
||||
super().__init__(
|
||||
system_message=system_message,
|
||||
model=model,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
|
||||
def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
|
||||
action_space_prompt = self._get_tool_agents_prompt()
|
||||
result_message = system_message.create_new_instance(
|
||||
content=system_message.content.format(
|
||||
action_space=action_space_prompt
|
||||
)
|
||||
)
|
||||
if self.tool_agents is not None:
|
||||
self.code_interpreter.update_action_space(
|
||||
{tool.name: tool for tool in self.tool_agents}
|
||||
)
|
||||
return result_message
|
||||
|
||||
def _get_tool_agents_prompt(self) -> str:
|
||||
r"""Returns the action space prompt.
|
||||
|
||||
Returns:
|
||||
str: The action space prompt.
|
||||
"""
|
||||
if self.tool_agents is not None:
|
||||
return "\n".join(
|
||||
[
|
||||
f"*** {tool.name} ***:\n {tool.description}"
|
||||
for tool in self.tool_agents
|
||||
]
|
||||
)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def get_tool_agent_names(self) -> List[str]:
|
||||
r"""Returns the names of tool agents.
|
||||
|
||||
Returns:
|
||||
List[str]: The names of tool agents.
|
||||
"""
|
||||
if self.tool_agents is not None:
|
||||
return [tool.name for tool in self.tool_agents]
|
||||
else:
|
||||
return []
|
||||
|
||||
# ruff: noqa: E501
|
||||
def step(self, input_message: BaseMessage) -> ChatAgentResponse: # type: ignore[override]
|
||||
r"""Performs a step in the conversation.
|
||||
|
||||
Args:
|
||||
input_message (BaseMessage): The input message.
|
||||
|
||||
Returns:
|
||||
ChatAgentResponse: A struct containing the output messages,
|
||||
a boolean indicating whether the chat session has terminated,
|
||||
and information about the chat session.
|
||||
"""
|
||||
response = super().step(input_message)
|
||||
|
||||
if response.msgs is None or len(response.msgs) == 0:
|
||||
raise RuntimeError("Got None output messages.")
|
||||
if response.terminated:
|
||||
raise RuntimeError(f"{self.__class__.__name__} step failed.")
|
||||
|
||||
# NOTE: Only single output messages are supported
|
||||
explanations, codes = response.msg.extract_text_and_code_prompts()
|
||||
|
||||
if self.verbose:
|
||||
for explanation, code in zip(explanations, codes):
|
||||
print_text_animated(
|
||||
self.logger_color + f"> Explanation:\n{explanation}"
|
||||
)
|
||||
print_text_animated(self.logger_color + f"> Code:\n{code}")
|
||||
|
||||
if len(explanations) > len(codes):
|
||||
print_text_animated(
|
||||
self.logger_color + f"> Explanation:\n{explanations[-1]}"
|
||||
)
|
||||
|
||||
content = response.msg.content
|
||||
|
||||
if codes is not None:
|
||||
try:
|
||||
content = "\n> Executed Results:\n"
|
||||
for block_idx, code in enumerate(codes):
|
||||
executed_output = self.code_interpreter.run(
|
||||
code, code.code_type
|
||||
)
|
||||
content += (
|
||||
f"Executing code block {block_idx}: {{\n"
|
||||
+ executed_output
|
||||
+ "}\n"
|
||||
)
|
||||
except InterruptedError as e:
|
||||
content = (
|
||||
f"\n> Running code fail: {e}\n"
|
||||
"Please regenerate the code."
|
||||
)
|
||||
|
||||
# TODO: Handle errors
|
||||
content = input_message.content + f"\n> Embodied Actions:\n{content}"
|
||||
message = BaseMessage(
|
||||
input_message.role_name,
|
||||
input_message.role_type,
|
||||
input_message.meta_dict,
|
||||
content,
|
||||
)
|
||||
return ChatAgentResponse(
|
||||
msgs=[message],
|
||||
terminated=response.terminated,
|
||||
info=response.info,
|
||||
)
|
||||
278
camel/agents/knowledge_graph_agent.py
Normal file
278
camel/agents/knowledge_graph_agent.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured.documents.elements import Element
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.storages.graph_storages.graph_element import (
|
||||
GraphElement,
|
||||
Node,
|
||||
Relationship,
|
||||
)
|
||||
from camel.types import RoleType
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
text_prompt = """
|
||||
You are tasked with extracting nodes and relationships from given content and
|
||||
structures them into Node and Relationship objects. Here's the outline of what
|
||||
you needs to do:
|
||||
|
||||
Content Extraction:
|
||||
You should be able to process input content and identify entities mentioned
|
||||
within it.
|
||||
Entities can be any noun phrases or concepts that represent distinct entities
|
||||
in the context of the given content.
|
||||
|
||||
Node Extraction:
|
||||
For each identified entity, you should create a Node object.
|
||||
Each Node object should have a unique identifier (id) and a type (type).
|
||||
Additional properties associated with the node can also be extracted and
|
||||
stored.
|
||||
|
||||
Relationship Extraction:
|
||||
You should identify relationships between entities mentioned in the content.
|
||||
For each relationship, create a Relationship object.
|
||||
A Relationship object should have a subject (subj) and an object (obj) which
|
||||
are Node objects representing the entities involved in the relationship.
|
||||
Each relationship should also have a type (type), and additional properties if
|
||||
applicable.
|
||||
|
||||
Output Formatting:
|
||||
The extracted nodes and relationships should be formatted as instances of the
|
||||
provided Node and Relationship classes.
|
||||
Ensure that the extracted data adheres to the structure defined by the classes.
|
||||
Output the structured data in a format that can be easily validated against
|
||||
the provided code.
|
||||
Do not wrap the output in lists or dictionaries, provide the Node and
|
||||
Relationship with unique identifiers.
|
||||
Strictly follow the format provided in the example output, do not add any
|
||||
additional information.
|
||||
|
||||
|
||||
Instructions for you:
|
||||
Read the provided content thoroughly.
|
||||
Identify distinct entities mentioned in the content and categorize them as
|
||||
nodes.
|
||||
Determine relationships between these entities and represent them as directed
|
||||
relationships.
|
||||
Provide the extracted nodes and relationships in the specified format below.
|
||||
Example for you:
|
||||
|
||||
Example Content:
|
||||
"John works at XYZ Corporation. He is a software engineer. The company is
|
||||
located in New York City."
|
||||
|
||||
Expected Output:
|
||||
|
||||
Nodes:
|
||||
|
||||
Node(id='John', type='Person')
|
||||
Node(id='XYZ Corporation', type='Organization')
|
||||
Node(id='New York City', type='Location')
|
||||
|
||||
Relationships:
|
||||
|
||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
|
||||
Corporation', type='Organization'), type='WorksAt')
|
||||
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
|
||||
type='Location'), type='ResidesIn')
|
||||
|
||||
===== TASK =====
|
||||
Please extracts nodes and relationships from given content and structures them
|
||||
into Node and Relationship objects.
|
||||
|
||||
{task}
|
||||
"""
|
||||
|
||||
|
||||
@track_agent(name="KnowledgeGraphAgent")
|
||||
class KnowledgeGraphAgent(ChatAgent):
|
||||
r"""An agent that can extract node and relationship information for
|
||||
different entities from given `Element` content.
|
||||
|
||||
Attributes:
|
||||
task_prompt (TextPrompt): A prompt for the agent to extract node and
|
||||
relationship information for different entities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
r"""Initialize the `KnowledgeGraphAgent`.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
system_message = BaseMessage(
|
||||
role_name="Graphify",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="Your mission is to transform unstructured content "
|
||||
"into structured graph data. Extract nodes and relationships with "
|
||||
"precision, and let the connections unfold. Your graphs will "
|
||||
"illuminate the hidden connections within the chaos of "
|
||||
"information.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def run(
|
||||
self,
|
||||
element: "Element",
|
||||
parse_graph_elements: bool = False,
|
||||
prompt: Optional[str] = None,
|
||||
) -> Union[str, GraphElement]:
|
||||
r"""Run the agent to extract node and relationship information.
|
||||
|
||||
Args:
|
||||
element (Element): The input element.
|
||||
parse_graph_elements (bool, optional): Whether to parse into
|
||||
`GraphElement`. Defaults to `False`.
|
||||
prompt (str, optional): The custom prompt to be used.
|
||||
Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
Union[str, GraphElement]: The extracted node and relationship
|
||||
information. If `parse_graph_elements` is `True` then return
|
||||
`GraphElement`, else return `str`.
|
||||
"""
|
||||
self.reset()
|
||||
self.element = element
|
||||
|
||||
# Use the provided prompt or fall back to the default text_prompt
|
||||
final_prompt = prompt if prompt is not None else text_prompt
|
||||
|
||||
knowledge_graph_prompt = TextPrompt(final_prompt)
|
||||
knowledge_graph_generation = knowledge_graph_prompt.format(
|
||||
task=str(element)
|
||||
)
|
||||
|
||||
response = self.step(input_message=knowledge_graph_generation)
|
||||
|
||||
content = response.msg.content
|
||||
|
||||
if parse_graph_elements:
|
||||
content = self._parse_graph_elements(content)
|
||||
|
||||
return content
|
||||
|
||||
def _validate_node(self, node: Node) -> bool:
|
||||
r"""Validate if the object is a valid Node.
|
||||
|
||||
Args:
|
||||
node (Node): Object to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the object is a valid Node, False otherwise.
|
||||
"""
|
||||
return (
|
||||
isinstance(node, Node)
|
||||
and isinstance(node.id, (str, int))
|
||||
and isinstance(node.type, str)
|
||||
)
|
||||
|
||||
def _validate_relationship(self, relationship: Relationship) -> bool:
|
||||
r"""Validate if the object is a valid Relationship.
|
||||
|
||||
Args:
|
||||
relationship (Relationship): Object to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the object is a valid Relationship, False otherwise.
|
||||
"""
|
||||
return (
|
||||
isinstance(relationship, Relationship)
|
||||
and self._validate_node(relationship.subj)
|
||||
and self._validate_node(relationship.obj)
|
||||
and isinstance(relationship.type, str)
|
||||
)
|
||||
|
||||
def _parse_graph_elements(self, input_string: str) -> GraphElement:
|
||||
r"""Parses graph elements from given content.
|
||||
|
||||
Args:
|
||||
input_string (str): The input content.
|
||||
|
||||
Returns:
|
||||
GraphElement: The parsed graph elements.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Regular expressions to extract nodes and relationships
|
||||
node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
|
||||
rel_pattern = (
|
||||
r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
|
||||
r"obj=Node\(id='(.*?)', type='(.*?)'\), "
|
||||
r"type='(.*?)'(?:, timestamp='(.*?)')?\)"
|
||||
)
|
||||
|
||||
nodes = {}
|
||||
relationships = []
|
||||
|
||||
# Extract nodes
|
||||
for match in re.finditer(node_pattern, input_string):
|
||||
id, type = match.groups()
|
||||
properties = {'source': 'agent_created'}
|
||||
if id not in nodes:
|
||||
node = Node(id=id, type=type, properties=properties)
|
||||
if self._validate_node(node):
|
||||
nodes[id] = node
|
||||
|
||||
# Extract relationships
|
||||
for match in re.finditer(rel_pattern, input_string):
|
||||
groups = match.groups()
|
||||
if len(groups) == 6:
|
||||
subj_id, subj_type, obj_id, obj_type, rel_type, timestamp = (
|
||||
groups
|
||||
)
|
||||
else:
|
||||
subj_id, subj_type, obj_id, obj_type, rel_type = groups
|
||||
timestamp = None
|
||||
properties = {'source': 'agent_created'}
|
||||
if subj_id in nodes and obj_id in nodes:
|
||||
subj = nodes[subj_id]
|
||||
obj = nodes[obj_id]
|
||||
relationship = Relationship(
|
||||
subj=subj,
|
||||
obj=obj,
|
||||
type=rel_type,
|
||||
timestamp=timestamp,
|
||||
properties=properties,
|
||||
)
|
||||
if self._validate_relationship(relationship):
|
||||
relationships.append(relationship)
|
||||
|
||||
return GraphElement(
|
||||
nodes=list(nodes.values()),
|
||||
relationships=relationships,
|
||||
source=self.element,
|
||||
)
|
||||
117
camel/agents/multi_hop_generator_agent.py
Normal file
117
camel/agents/multi_hop_generator_agent.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from camel.agents.programmed_agent_instruction import (
|
||||
ProgrammableChatAgent,
|
||||
ProgrammedAgentInstructionResult,
|
||||
programmable_capability,
|
||||
)
|
||||
from camel.datagen.source2synth.models import (
|
||||
ContextPrompt,
|
||||
MultiHopQA,
|
||||
)
|
||||
from camel.messages import BaseMessage
|
||||
|
||||
|
||||
class MultiHopGeneratorAgent(ProgrammableChatAgent):
|
||||
r"""An agent specialized in generating multi-hop question-answer pairs.
|
||||
|
||||
This agent is designed to create complex questions that require multiple
|
||||
steps of reasoning to answer. It analyzes context to identify related
|
||||
facts and generates questions that require connecting these facts
|
||||
logically.
|
||||
|
||||
Attributes:
|
||||
model_config (ConfigDict): Configuration for model behavior.
|
||||
system_message (BaseMessage): System message defining agent's role and
|
||||
instructions.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
r"""Initialize the MultiHopGeneratorAgent.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Additional keyword arguments to pass to parent
|
||||
class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
system_text: str = textwrap.dedent(
|
||||
"""\
|
||||
You are an expert at generating
|
||||
multi-hop question-answer pairs.
|
||||
For each context, you should:
|
||||
1. Identify multiple related facts or pieces of information
|
||||
2. Create questions that require reasoning across these multiple pieces
|
||||
3. Ensure the reasoning chain is clear and logical
|
||||
4. Generate questions that require at least 2-3 steps of reasoning
|
||||
5. Include the reasoning steps in the answer
|
||||
|
||||
Give your response with this information:
|
||||
Question: [Complex question requiring multiple reasoning steps]
|
||||
Reasoning Steps:
|
||||
1. [First reasoning step]
|
||||
2. [Second reasoning step]
|
||||
3. [Final reasoning step]
|
||||
Answer: [Final answer]
|
||||
Supporting Facts: [List of relevant text segments used]
|
||||
""" # noqa: E501
|
||||
)
|
||||
self._system_message = BaseMessage.make_assistant_message(
|
||||
role_name='Assistant', content=system_text
|
||||
)
|
||||
|
||||
@programmable_capability
|
||||
def generate_multi_hop_qa(
|
||||
self, context: str
|
||||
) -> ProgrammedAgentInstructionResult[MultiHopQA]:
|
||||
r"""Generate a multi-hop question-answer pair from given context.
|
||||
|
||||
Args:
|
||||
context (str): The input text context to generate QA from.
|
||||
|
||||
Returns:
|
||||
ProgrammedAgentInstructionResult[MultiHopQA]: Result containing the
|
||||
generated question, reasoning steps, answer, and supporting
|
||||
facts.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the agent fails to generate a response.
|
||||
"""
|
||||
context_prompt = ContextPrompt(
|
||||
main_context=context, related_contexts=None
|
||||
)
|
||||
|
||||
user_message = BaseMessage.make_user_message(
|
||||
content=context_prompt.model_dump_json(), role_name="User"
|
||||
)
|
||||
response = self.step(
|
||||
input_message=user_message, response_format=MultiHopQA
|
||||
)
|
||||
value = MultiHopQA.model_validate_json(response.msgs[0].content)
|
||||
|
||||
if response.msgs:
|
||||
return ProgrammedAgentInstructionResult(
|
||||
user_message=user_message,
|
||||
agent_message=response.msgs[0],
|
||||
value=value,
|
||||
)
|
||||
raise RuntimeError("No response from agent")
|
||||
203
camel/agents/programmed_agent_instruction.py
Normal file
203
camel/agents/programmed_agent_instruction.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import abc
|
||||
import threading
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class ProgrammableAgentRequirement(Enum):
|
||||
r"""Requirements for programmable agent state.
|
||||
|
||||
Defines the possible requirements that can be used to repair the state
|
||||
of a programmable agent.
|
||||
|
||||
Attributes:
|
||||
LAST_MESSAGE_NOT_USER (str): Requires that the last message in the
|
||||
conversation was not from the user.
|
||||
"""
|
||||
|
||||
LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"
|
||||
|
||||
|
||||
class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
|
||||
r"""Result of a programmable agent instruction execution.
|
||||
|
||||
Contains the messages exchanged during execution and the computed value.
|
||||
The value type is specified by the generic type parameter T.
|
||||
|
||||
Attributes:
|
||||
user_message (BaseMessage): The message sent by the user.
|
||||
agent_message (BaseMessage): The message sent by the agent.
|
||||
value (T): The computed result value of type T.
|
||||
"""
|
||||
|
||||
user_message: BaseMessage
|
||||
agent_message: BaseMessage
|
||||
value: T
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class AbstractProgrammableAgent(abc.ABC):
|
||||
r"""Abstract class for a programmable agent.
|
||||
|
||||
A programmable agent is an agent that can be programmed to perform a
|
||||
specific function or task. This class defines the interface for a
|
||||
programmable agent.
|
||||
|
||||
These methods should be implemented in order to ensure the agent supports
|
||||
the necessary guarantees to enable a programming interface while
|
||||
maintaining compatibility in a multi-agent system.
|
||||
|
||||
A programmable agent is responsible for providing and maintaining a
|
||||
programming interface for its functionality.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run_atomic(
|
||||
self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
|
||||
) -> ProgrammedAgentInstructionResult[T]:
|
||||
r"""Run an atomic operation on the agent.
|
||||
|
||||
An atomic operation is an operation that is guaranteed to
|
||||
be executed without interruption by any other operation.
|
||||
|
||||
Args:
|
||||
callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
|
||||
operation to execute atomically.
|
||||
|
||||
Returns:
|
||||
ProgrammedAgentInstructionResult[T]: The result of the operation.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an operation is already in progress.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
|
||||
r"""Repair the state of the agent.
|
||||
|
||||
Agents may have other non-atomic interfaces, such as a user interface,
|
||||
or chat between other agents. This method should restore the agent to
|
||||
a state where it can perform operations according to the specified
|
||||
requirement.
|
||||
|
||||
Args:
|
||||
requirement (ProgrammableAgentRequirement): The requirement to
|
||||
repair the state for.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def programmable_capability(
|
||||
func: Callable[..., ProgrammedAgentInstructionResult[T]],
|
||||
) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
|
||||
r"""Decorator for programmable agent capabilities.
|
||||
|
||||
This decorator ensures that the decorated method is executed atomically
|
||||
and maintains the agent's state guarantees.
|
||||
|
||||
Args:
|
||||
func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
|
||||
to decorate.
|
||||
|
||||
Returns:
|
||||
Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
|
||||
method that ensures atomic execution.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> ProgrammedAgentInstructionResult[T]:
|
||||
return self.run_atomic(lambda: func(self, *args, **kwargs))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
|
||||
r"""A chat agent that can be programmed to perform specific tasks.
|
||||
|
||||
Provides a default implementation of atomic execution using threading locks
|
||||
and basic state tracking for message roles. Implementing classes need to
|
||||
provide specific repair logic for their use cases.
|
||||
|
||||
Attributes:
|
||||
_operation_lock (threading.Lock): Lock for ensuring atomic operations.
|
||||
_last_message_role (Optional[str]): Role of the last message in the
|
||||
conversation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
r"""Initialize the ProgrammableChatAgent.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Additional keyword arguments to pass to parent
|
||||
class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._operation_lock = threading.Lock()
|
||||
self._last_message_role: Optional[str] = None
|
||||
|
||||
def run_atomic(
|
||||
self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
|
||||
) -> ProgrammedAgentInstructionResult[T]:
|
||||
r"""Run an atomic operation on the agent.
|
||||
|
||||
Ensures thread-safe execution of the callback function by using a lock.
|
||||
|
||||
Args:
|
||||
callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
|
||||
operation to execute atomically.
|
||||
|
||||
Returns:
|
||||
ProgrammedAgentInstructionResult[T]: The result of the operation.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an operation is already in progress.
|
||||
"""
|
||||
if not self._operation_lock.acquire(blocking=False):
|
||||
raise RuntimeError("Operation already in progress")
|
||||
|
||||
try:
|
||||
result = callback()
|
||||
self._last_message_role = result.agent_message.role_name
|
||||
return result
|
||||
finally:
|
||||
self._operation_lock.release()
|
||||
|
||||
def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
|
||||
r"""Repair the state of the agent.
|
||||
|
||||
Implements basic state repair for message role requirements.
|
||||
|
||||
Args:
|
||||
requirement (ProgrammableAgentRequirement): The requirement to
|
||||
repair the state for.
|
||||
"""
|
||||
if requirement == ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER:
|
||||
if self._last_message_role == "user":
|
||||
raise NotImplementedError(
|
||||
"Must implement repair for LAST_MESSAGE_NOT_USER"
|
||||
)
|
||||
579
camel/agents/repo_agent.py
Normal file
579
camel/agents/repo_agent.py
Normal file
@@ -0,0 +1,579 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
from string import Template
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from github.MainClass import Github
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend, ModelFactory
|
||||
from camel.responses import ChatAgentResponse
|
||||
from camel.retrievers import VectorRetriever
|
||||
from camel.types import (
|
||||
ModelPlatformType,
|
||||
ModelType,
|
||||
OpenAIBackendRole,
|
||||
RoleType,
|
||||
)
|
||||
from camel.utils import track_agent
|
||||
from camel.utils.chunker import CodeChunker
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ProcessingMode(Enum):
|
||||
FULL_CONTEXT = auto()
|
||||
RAG = auto()
|
||||
|
||||
|
||||
class GitHubFile(BaseModel):
|
||||
r"""Model to hold GitHub file information.
|
||||
|
||||
Attributes:
|
||||
content (str): The content of the GitHub text.
|
||||
file_path (str): The path of the file.
|
||||
html_url (str): The actual url of the file.
|
||||
"""
|
||||
|
||||
content: str
|
||||
file_path: str
|
||||
html_url: str
|
||||
|
||||
|
||||
class RepositoryInfo(BaseModel):
|
||||
r"""Model to hold GitHub repository information.
|
||||
|
||||
Attributes:
|
||||
repo_name (str): The full name of the repository.
|
||||
repo_url (str): The URL of the repository.
|
||||
contents (list): A list to hold the repository contents.
|
||||
"""
|
||||
|
||||
repo_name: str
|
||||
repo_url: str
|
||||
contents: List[GitHubFile] = []
|
||||
|
||||
|
||||
@track_agent(name="RepoAgent")
|
||||
class RepoAgent(ChatAgent):
|
||||
r"""A specialized agent designed to interact with GitHub repositories for
|
||||
code generation tasks.
|
||||
The RepoAgent enhances a base ChatAgent by integrating context from
|
||||
one or more GitHub repositories. It supports two processing modes:
|
||||
- FULL_CONTEXT: loads and injects full repository content into the
|
||||
prompt.
|
||||
- RAG (Retrieval-Augmented Generation): retrieves relevant
|
||||
code/documentation chunks using a vector store when context
|
||||
length exceeds a specified token limit.
|
||||
|
||||
Attributes:
|
||||
vector_retriever (VectorRetriever): Retriever used to
|
||||
perform semantic search in RAG mode. Required if repo content
|
||||
exceeds context limit.
|
||||
system_message (Optional[str]): The system message
|
||||
for the chat agent. (default: :str:`"You are a code assistant
|
||||
with repo context."`)
|
||||
repo_paths (Optional[List[str]]): List of GitHub repository URLs to
|
||||
load during initialization. (default: :obj:`None`)
|
||||
model (BaseModelBackend): The model backend to use for generating
|
||||
responses. (default: :obj:`ModelPlatformType.DEFAULT`
|
||||
with `ModelType.DEFAULT`)
|
||||
max_context_tokens (Optional[int]): Maximum number of tokens allowed
|
||||
before switching to RAG mode. (default: :obj:`2000`)
|
||||
github_auth_token (Optional[str]): GitHub personal access token
|
||||
for accessing private or rate-limited repositories. (default:
|
||||
:obj:`None`)
|
||||
chunk_size (Optional[int]): Maximum number of characters per code chunk
|
||||
when indexing files for RAG. (default: :obj:`8192`)
|
||||
top_k (int): Number of top-matching chunks to retrieve from the vector
|
||||
store in RAG mode. (default: :obj:`5`)
|
||||
similarity (Optional[float]): Minimum similarity score required to
|
||||
include a chunk in the RAG context. (default: :obj:`0.6`)
|
||||
collection_name (Optional[str]): Name of the vector database
|
||||
collection to use for storing and retrieving chunks. (default:
|
||||
:obj:`None`)
|
||||
**kwargs: Inherited from ChatAgent
|
||||
|
||||
Note:
|
||||
The current implementation of RAG mode requires using Qdrant as the
|
||||
vector storage backend. The VectorRetriever defaults to QdrantStorage
|
||||
if no storage is explicitly provided. Other vector storage backends
|
||||
are not currently supported for the RepoAgent's RAG functionality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_retriever: VectorRetriever,
|
||||
system_message: Optional[
|
||||
str
|
||||
] = "You are a code assistant with repo context.",
|
||||
repo_paths: Optional[List[str]] = None,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
max_context_tokens: int = 2000,
|
||||
github_auth_token: Optional[str] = None,
|
||||
chunk_size: Optional[int] = 8192,
|
||||
top_k: Optional[int] = 5,
|
||||
similarity: Optional[float] = 0.6,
|
||||
collection_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if model is None:
|
||||
model = ModelFactory.create(
|
||||
model_platform=ModelPlatformType.DEFAULT,
|
||||
model_type=ModelType.DEFAULT,
|
||||
)
|
||||
|
||||
super().__init__(system_message=system_message, model=model, **kwargs)
|
||||
self.max_context_tokens = max_context_tokens
|
||||
self.vector_retriever = vector_retriever
|
||||
self.github_auth_token = github_auth_token
|
||||
self.chunk_size = chunk_size
|
||||
self.num_tokens = 0
|
||||
self.processing_mode = ProcessingMode.FULL_CONTEXT
|
||||
self.top_k = top_k
|
||||
self.similarity = similarity
|
||||
self.collection_name = collection_name
|
||||
self.prompt_template = Template(
|
||||
"$type: $repo\n"
|
||||
"You are an AI coding assistant. "
|
||||
"Your task is to generate code based on provided GitHub "
|
||||
"repositories. \n"
|
||||
"### Instructions: \n1. **Analyze the Repositories**: "
|
||||
"Identify which repositories contain relevant "
|
||||
"information for the user's request. Ignore unrelated ones.\n"
|
||||
"2. **Extract Context**: Use code, documentation, "
|
||||
"dependencies, and tests to understand functionality.\n"
|
||||
"3. **Generate Code**: Create clean, efficient, and "
|
||||
"well-structured code that aligns with relevant repositories. \n"
|
||||
"4. **Justify Output**: Explain which repositories "
|
||||
"influenced your solution and why others were ignored."
|
||||
"\n If the repositories lack necessary details, "
|
||||
"infer best practices and suggest improvements.\n"
|
||||
"Now, analyze the repositories and generate the "
|
||||
"required code."
|
||||
)
|
||||
self.full_text = ""
|
||||
self.chunker = CodeChunker(chunk_size=chunk_size or 8192)
|
||||
self.repos: List[RepositoryInfo] = []
|
||||
if repo_paths:
|
||||
self.repos = self.load_repositories(repo_paths)
|
||||
if len(self.repos) > 0:
|
||||
self.construct_full_text()
|
||||
self.num_tokens = self.count_tokens()
|
||||
if not self.check_switch_mode():
|
||||
self.update_memory(
|
||||
message=BaseMessage.make_user_message(
|
||||
role_name=RoleType.USER.value,
|
||||
content=self.full_text,
|
||||
),
|
||||
role=OpenAIBackendRole.SYSTEM,
|
||||
)
|
||||
|
||||
def parse_url(self, url: str) -> Tuple[str, str]:
|
||||
r"""Parse the GitHub URL and return the (owner, repo_name) tuple.
|
||||
|
||||
Args:
|
||||
url (str): The URL to be parsed.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: The (owner, repo_name) tuple.
|
||||
"""
|
||||
try:
|
||||
url_path = url.replace("https://github.com/", "")
|
||||
parts = url_path.split("/")
|
||||
if len(parts) != 2:
|
||||
raise ValueError("Incorrect GitHub repo URL format.")
|
||||
else:
|
||||
return parts[0], parts[1]
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing URL: {e}")
|
||||
raise Exception(e)
|
||||
|
||||
def load_repositories(
|
||||
self,
|
||||
repo_urls: List[str],
|
||||
) -> List[RepositoryInfo]:
|
||||
r"""Load the content of a GitHub repository.
|
||||
|
||||
Args:
|
||||
repo_urls (str): The list of Repo URLs.
|
||||
|
||||
Returns:
|
||||
List[RepositoryInfo]: A list of objects containing information
|
||||
about the all repositories, including the contents.
|
||||
"""
|
||||
from github.MainClass import Github
|
||||
|
||||
github_client = Github(self.github_auth_token)
|
||||
res = []
|
||||
|
||||
for repo_url in repo_urls:
|
||||
try:
|
||||
res.append(self.load_repository(repo_url, github_client))
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading repository: {e}")
|
||||
raise Exception(e)
|
||||
time.sleep(1)
|
||||
logger.info(f"Successfully loaded {len(res)} repositories.")
|
||||
return res
|
||||
|
||||
def load_repository(
|
||||
self,
|
||||
repo_url: str,
|
||||
github_client: "Github",
|
||||
) -> RepositoryInfo:
|
||||
r"""Load the content of a GitHub repository.
|
||||
|
||||
Args:
|
||||
repo_urls (str): The Repo URL to be loaded.
|
||||
github_client (GitHub): The established GitHub client.
|
||||
|
||||
Returns:
|
||||
RepositoryInfo: The object containing information
|
||||
about the repository, including the contents.
|
||||
"""
|
||||
from github.ContentFile import ContentFile
|
||||
|
||||
try:
|
||||
owner, repo_name = self.parse_url(repo_url)
|
||||
repo = github_client.get_repo(f"{owner}/{repo_name}")
|
||||
contents = repo.get_contents("")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading repository: {e}")
|
||||
raise Exception(e)
|
||||
|
||||
info = RepositoryInfo(
|
||||
repo_name=repo.full_name,
|
||||
repo_url=repo.html_url,
|
||||
contents=[],
|
||||
)
|
||||
|
||||
# Create a list to process repository contents
|
||||
content_list: List[ContentFile] = []
|
||||
if isinstance(contents, list):
|
||||
content_list = contents
|
||||
else:
|
||||
# Handle single ContentFile case
|
||||
content_list = [contents]
|
||||
|
||||
while content_list:
|
||||
file = content_list.pop(0)
|
||||
if file.type == "file":
|
||||
if any(
|
||||
file.path.endswith(ext)
|
||||
for ext in [
|
||||
".png",
|
||||
".jpg",
|
||||
".pdf",
|
||||
".zip",
|
||||
".gitignore",
|
||||
".mp4",
|
||||
".avi",
|
||||
".mov",
|
||||
".mp3",
|
||||
".wav",
|
||||
".tar",
|
||||
".gz",
|
||||
".7z",
|
||||
".rar",
|
||||
".iso",
|
||||
".gif",
|
||||
".docx",
|
||||
]
|
||||
):
|
||||
logger.info(f"Skipping binary file: {file.path}")
|
||||
continue
|
||||
try:
|
||||
file_obj = repo.get_contents(file.path)
|
||||
|
||||
# Handle file_obj which could be a single ContentFile or a
|
||||
# list
|
||||
if isinstance(file_obj, list):
|
||||
if not file_obj: # Skip empty lists
|
||||
continue
|
||||
file_obj = file_obj[
|
||||
0
|
||||
] # Take the first item if it's a list
|
||||
|
||||
if getattr(file_obj, "encoding", None) != "base64":
|
||||
logger.warning(
|
||||
f"Skipping file with unsupported "
|
||||
f"encoding: {file.path}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
content_bytes = file_obj.decoded_content
|
||||
file_content = content_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.warning(f"Skipping non-UTF-8 file: {file.path}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to decode file content at "
|
||||
f"{file.path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
github_file = GitHubFile(
|
||||
content=file_content,
|
||||
file_path=f"{owner}/{repo_name}/{file.path}",
|
||||
html_url=file.html_url,
|
||||
)
|
||||
info.contents.append(github_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {e}")
|
||||
raise Exception(e)
|
||||
logger.info(f"Successfully loaded file: {file.path}")
|
||||
elif file.type == "dir":
|
||||
dir_contents = repo.get_contents(file.path)
|
||||
# Handle dir_contents which could be a single ContentFile or a
|
||||
# list
|
||||
if isinstance(dir_contents, list):
|
||||
content_list.extend(dir_contents)
|
||||
else:
|
||||
content_list.append(dir_contents)
|
||||
return info
|
||||
|
||||
def count_tokens(self) -> int:
|
||||
r"""To count the tokens that's currently in the memory
|
||||
|
||||
Returns:
|
||||
int: The number of tokens
|
||||
"""
|
||||
counter = self.model_backend.token_counter
|
||||
content_token_count = counter.count_tokens_from_messages(
|
||||
messages=[
|
||||
BaseMessage.make_user_message(
|
||||
role_name=RoleType.USER.value,
|
||||
content=self.full_text,
|
||||
).to_openai_message(OpenAIBackendRole.USER)
|
||||
]
|
||||
)
|
||||
return content_token_count
|
||||
|
||||
def construct_full_text(self):
|
||||
r"""Construct full context text from repositories by concatenation."""
|
||||
repo_texts = [
|
||||
{"content": f.content, "path": f.file_path}
|
||||
for repo in self.repos
|
||||
for f in repo.contents
|
||||
]
|
||||
self.full_text = self.prompt_template.safe_substitute(
|
||||
type="Repository",
|
||||
repo="\n".join(
|
||||
f"{repo['path']}\n{repo['content']}" for repo in repo_texts
|
||||
),
|
||||
)
|
||||
|
||||
def add_repositories(self, repo_urls: List[str]):
|
||||
r"""Add a GitHub repository to the list of repositories.
|
||||
|
||||
Args:
|
||||
repo_urls (str): The Repo URL to be added.
|
||||
"""
|
||||
new_repos = self.load_repositories(repo_urls)
|
||||
self.repos.extend(new_repos)
|
||||
self.construct_full_text()
|
||||
self.num_tokens = self.count_tokens()
|
||||
if self.processing_mode == ProcessingMode.RAG:
|
||||
for repo in new_repos:
|
||||
for f in repo.contents:
|
||||
self.vector_retriever.process(
|
||||
content=f.content,
|
||||
should_chunk=True,
|
||||
extra_info={"file_path": f.file_path},
|
||||
chunker=self.chunker,
|
||||
)
|
||||
else:
|
||||
self.check_switch_mode()
|
||||
|
||||
def check_switch_mode(self) -> bool:
|
||||
r"""Check if the current context exceeds the context window; if so,
|
||||
switch to RAG mode.
|
||||
|
||||
Returns:
|
||||
bool: True if the mode was switched, False otherwise.
|
||||
"""
|
||||
if self.processing_mode == ProcessingMode.RAG:
|
||||
return False
|
||||
|
||||
if self.num_tokens > self.max_context_tokens:
|
||||
if not self.vector_retriever:
|
||||
logger.warning(
|
||||
f"Token count ({self.num_tokens}) exceeds limit "
|
||||
f"({self.max_context_tokens}). "
|
||||
"Either reduce repository size or provide a "
|
||||
"VectorRetriever."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("Switching to RAG mode and indexing repositories...")
|
||||
self.processing_mode = ProcessingMode.RAG
|
||||
for repo in self.repos:
|
||||
for f in repo.contents:
|
||||
self.vector_retriever.process(
|
||||
content=f.content,
|
||||
should_chunk=True,
|
||||
extra_info={"file_path": f.file_path},
|
||||
chunker=self.chunker,
|
||||
)
|
||||
self._system_message = None
|
||||
self.reset()
|
||||
return True
|
||||
return False
|
||||
|
||||
def step(
|
||||
self, input_message: Union[BaseMessage, str], *args, **kwargs
|
||||
) -> ChatAgentResponse:
|
||||
r"""Overrides `ChatAgent.step()` to first retrieve relevant context
|
||||
from the vector store before passing the input to the language model.
|
||||
"""
|
||||
if (
|
||||
self.processing_mode == ProcessingMode.RAG
|
||||
and self.vector_retriever
|
||||
):
|
||||
if isinstance(input_message, BaseMessage):
|
||||
user_query = input_message.content
|
||||
else:
|
||||
user_query = input_message
|
||||
retrieved_content = []
|
||||
retries = 1
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
raw_rag_content = self.vector_retriever.query(
|
||||
query=user_query,
|
||||
top_k=self.top_k or 5,
|
||||
similarity_threshold=self.similarity or 0.6,
|
||||
)
|
||||
# Remove duplicates and retrieve the whole file
|
||||
paths = []
|
||||
for record in raw_rag_content:
|
||||
file_path = record["extra_info"]["file_path"]
|
||||
if file_path not in paths:
|
||||
retrieved_content.append(
|
||||
{
|
||||
"content": self.search_by_file_path(
|
||||
file_path
|
||||
),
|
||||
"similarity": record["similarity score"],
|
||||
}
|
||||
)
|
||||
paths.append(file_path)
|
||||
|
||||
retrieved_content = sorted(
|
||||
retrieved_content,
|
||||
key=lambda x: x["similarity"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
full_prompt = self.prompt_template.safe_substitute(
|
||||
type="Retrieved code",
|
||||
repo="\n".join(
|
||||
[record["content"] for record in retrieved_content]
|
||||
),
|
||||
)
|
||||
|
||||
new_query = user_query + "\n" + full_prompt
|
||||
if isinstance(input_message, BaseMessage):
|
||||
input_message.content = new_query
|
||||
else:
|
||||
input_message = BaseMessage.make_user_message(
|
||||
role_name="User", content=new_query
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
if attempt < retries - 1:
|
||||
sleep_time = 2**attempt
|
||||
logger.info(
|
||||
f"Retrying qdrant query in {sleep_time} seconds..."
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to query qdrant record after {retries} "
|
||||
"attempts."
|
||||
)
|
||||
|
||||
return super().step(input_message, *args, **kwargs)
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
if self.processing_mode == ProcessingMode.FULL_CONTEXT:
|
||||
message = BaseMessage.make_user_message(
|
||||
role_name=RoleType.USER.value,
|
||||
content=self.full_text,
|
||||
)
|
||||
self.update_memory(message, OpenAIBackendRole.SYSTEM)
|
||||
else:
|
||||
self.num_tokens = 0
|
||||
|
||||
def search_by_file_path(self, file_path: str) -> str:
|
||||
r"""Search for all payloads in the vector database where
|
||||
file_path matches the given value (the same file),
|
||||
then sort by piece_num and concatenate text fields to return a
|
||||
complete result.
|
||||
|
||||
Args:
|
||||
file_path (str): The `file_path` value to filter the payloads.
|
||||
|
||||
Returns:
|
||||
str: A concatenated string of the `text` fields sorted by
|
||||
`piece_num`.
|
||||
"""
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
try:
|
||||
storage_instance = self.vector_retriever.storage
|
||||
collection_name = (
|
||||
self.collection_name or storage_instance.collection_name # type: ignore[attr-defined]
|
||||
)
|
||||
source_data, _ = storage_instance.client.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=1000,
|
||||
scroll_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="extra_info.file_path",
|
||||
match=MatchValue(value=file_path),
|
||||
)
|
||||
]
|
||||
),
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during database initialization or scroll: {e}"
|
||||
)
|
||||
raise Exception(e)
|
||||
|
||||
results = []
|
||||
for point in source_data:
|
||||
payload = point.payload
|
||||
piece_num = payload["metadata"]["piece_num"]
|
||||
text = payload["text"]
|
||||
if piece_num is not None and text:
|
||||
results.append({"piece_num": piece_num, "text": text})
|
||||
|
||||
sorted_results = sorted(results, key=lambda x: x["piece_num"])
|
||||
full_doc = "\n".join([item["text"] for item in sorted_results])
|
||||
|
||||
return full_doc
|
||||
141
camel/agents/role_assignment_agent.py
Normal file
141
camel/agents/role_assignment_agent.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import re
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="RoleAssignmentAgent")
|
||||
class RoleAssignmentAgent(ChatAgent):
|
||||
r"""An agent that generates role names based on the task prompt.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
|
||||
Attributes:
|
||||
role_assignment_prompt (TextPrompt): A prompt for the agent to generate
|
||||
role names.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Role Assigner",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You assign roles based on tasks.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_prompt: Union[str, TextPrompt],
|
||||
num_roles: int = 2,
|
||||
) -> Dict[str, str]:
|
||||
r"""Generate role names based on the input task prompt.
|
||||
|
||||
Args:
|
||||
task_prompt (Union[str, TextPrompt]): The prompt
|
||||
for the task based on which the roles are to be generated.
|
||||
num_roles (int, optional): The number of roles to generate.
|
||||
(default: :obj:`2`)
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary mapping role names to their
|
||||
descriptions.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
|
||||
f"Domain expert {i + 1}: <BLANK>\n"
|
||||
f"Associated competencies, characteristics, duties "
|
||||
f"and workflows: <BLANK>. End."
|
||||
for i in range(num_roles or 0)
|
||||
)
|
||||
role_assignment_generation_prompt = TextPrompt(
|
||||
"You are a role assignment agent, and you're in charge of "
|
||||
+ "recruiting {num_roles} experts for the following task."
|
||||
+ "\n==== TASK =====\n {task}\n\n"
|
||||
+ "Identify the domain experts you'd recruit and detail their "
|
||||
+ "associated competencies, characteristics, duties and workflows "
|
||||
+ "to complete the task.\n "
|
||||
+ "Your answer MUST adhere to the format of ANSWER PROMPT, and "
|
||||
+ "ONLY answer the BLANKs.\n"
|
||||
+ expert_prompt
|
||||
)
|
||||
role_assignment_generation = role_assignment_generation_prompt.format(
|
||||
num_roles=num_roles, task=task_prompt
|
||||
)
|
||||
|
||||
role_assignment_generation_msg = BaseMessage.make_user_message(
|
||||
role_name="Role Assigner", content=role_assignment_generation
|
||||
)
|
||||
|
||||
response = self.step(input_message=role_assignment_generation_msg)
|
||||
|
||||
msg = response.msg # type: BaseMessage
|
||||
terminated = response.terminated
|
||||
|
||||
# Distribute the output completions into role names and descriptions
|
||||
role_names = [
|
||||
desc.replace("<|", "").replace("|>", "")
|
||||
for desc in re.findall(
|
||||
r"Domain expert \d: (.+?)\nAssociated competencies,",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
]
|
||||
role_descriptions = [
|
||||
desc.replace("<|", "").replace("|>", "")
|
||||
for desc in re.findall(
|
||||
r"Associated competencies, characteristics, "
|
||||
r"duties and workflows: (.+?) End.",
|
||||
msg.content,
|
||||
re.DOTALL,
|
||||
)
|
||||
]
|
||||
|
||||
if len(role_names) != num_roles or len(role_descriptions) != num_roles:
|
||||
raise RuntimeError(
|
||||
"Got None or insufficient information of roles."
|
||||
)
|
||||
if terminated:
|
||||
raise RuntimeError("Role assignment failed.")
|
||||
|
||||
role_descriptions_dict = {
|
||||
role_name: description
|
||||
for role_name, description in zip(role_names, role_descriptions)
|
||||
}
|
||||
|
||||
return role_descriptions_dict
|
||||
133
camel/agents/search_agent.py
Normal file
133
camel/agents/search_agent.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Optional
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import TextPrompt
|
||||
from camel.types import RoleType
|
||||
from camel.utils import create_chunks
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="SearchAgent")
|
||||
class SearchAgent(ChatAgent):
|
||||
r"""An agent that summarizes text based on a query and evaluates the
|
||||
relevance of an answer.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
) -> None:
|
||||
system_message = BaseMessage(
|
||||
role_name="Assistant",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You are a helpful assistant.",
|
||||
)
|
||||
super().__init__(system_message, model=model)
|
||||
|
||||
def summarize_text(self, text: str, query: str) -> str:
|
||||
r"""Summarize the information from the text, base on the query.
|
||||
|
||||
Args:
|
||||
text (str): Text to summarize.
|
||||
query (str): What information you want.
|
||||
|
||||
Returns:
|
||||
str: Strings with information.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
summary_prompt = TextPrompt(
|
||||
'''Gather information from this text that relative to the
|
||||
question, but do not directly answer the question.\nquestion:
|
||||
{query}\ntext '''
|
||||
)
|
||||
summary_prompt = summary_prompt.format(query=query)
|
||||
# Max length of each chunk
|
||||
max_len = 3000
|
||||
results = ""
|
||||
chunks = create_chunks(text, max_len)
|
||||
# Summarize
|
||||
for i, chunk in enumerate(chunks, start=1):
|
||||
prompt = summary_prompt + str(i) + ": " + chunk
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
result = self.step(user_msg).msg.content
|
||||
results += result + "\n"
|
||||
|
||||
# Final summarization
|
||||
final_prompt = TextPrompt(
|
||||
'''Here are some summarized texts which split from one text. Using
|
||||
the information to answer the question. If can't find the answer,
|
||||
you must answer "I can not find the answer to the query" and
|
||||
explain why.\n Query:\n{query}.\n\nText:\n'''
|
||||
)
|
||||
final_prompt = final_prompt.format(query=query)
|
||||
prompt = final_prompt + results
|
||||
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
response = self.step(user_msg).msg.content
|
||||
|
||||
return response
|
||||
|
||||
def continue_search(self, query: str, answer: str) -> bool:
|
||||
r"""Ask whether to continue search or not based on the provided answer.
|
||||
|
||||
Args:
|
||||
query (str): The question.
|
||||
answer (str): The answer to the question.
|
||||
|
||||
Returns:
|
||||
bool: `True` if the user want to continue search, `False`
|
||||
otherwise.
|
||||
"""
|
||||
prompt = TextPrompt(
|
||||
"Do you think the ANSWER can answer the QUERY? "
|
||||
"Use only 'yes' or 'no' to answer.\n"
|
||||
"===== QUERY =====\n{query}\n\n"
|
||||
"===== ANSWER =====\n{answer}"
|
||||
)
|
||||
prompt = prompt.format(query=query, answer=answer)
|
||||
user_msg = BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=prompt,
|
||||
)
|
||||
response = self.step(user_msg).msg.content
|
||||
if "yes" in str(response).lower():
|
||||
return False
|
||||
return True
|
||||
410
camel/agents/task_agent.py
Normal file
410
camel/agents/task_agent.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from camel.agents.chat_agent import ChatAgent
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend
|
||||
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
||||
from camel.types import RoleType, TaskType
|
||||
from camel.utils import get_task_list
|
||||
|
||||
# AgentOps decorator setting
|
||||
try:
|
||||
import os
|
||||
|
||||
if os.getenv("AGENTOPS_API_KEY") is not None:
|
||||
from agentops import track_agent
|
||||
else:
|
||||
raise ImportError
|
||||
except (ImportError, AttributeError):
|
||||
from camel.utils import track_agent
|
||||
|
||||
|
||||
@track_agent(name="TaskSpecifyAgent")
|
||||
class TaskSpecifyAgent(ChatAgent):
|
||||
r"""An agent that specifies a given task prompt by prompting the user to
|
||||
provide more details.
|
||||
|
||||
Attributes:
|
||||
DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
|
||||
task_specify_prompt (TextPrompt): The prompt for specifying the task.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
task_type (TaskType, optional): The type of task for which to generate
|
||||
a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
|
||||
task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
|
||||
specifying the task. (default: :obj:`None`)
|
||||
word_limit (int, optional): The word limit for the task prompt.
|
||||
(default: :obj:`50`)
|
||||
output_language (str, optional): The language to be output by the
|
||||
agent. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
DEFAULT_WORD_LIMIT = 50
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
task_type: TaskType = TaskType.AI_SOCIETY,
|
||||
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
|
||||
word_limit: int = DEFAULT_WORD_LIMIT,
|
||||
output_language: Optional[str] = None,
|
||||
) -> None:
|
||||
self.task_specify_prompt: Union[str, TextPrompt]
|
||||
if task_specify_prompt is None:
|
||||
task_specify_prompt_template = (
|
||||
PromptTemplateGenerator().get_task_specify_prompt(task_type)
|
||||
)
|
||||
|
||||
self.task_specify_prompt = task_specify_prompt_template.format(
|
||||
word_limit=word_limit
|
||||
)
|
||||
else:
|
||||
self.task_specify_prompt = TextPrompt(task_specify_prompt)
|
||||
|
||||
system_message = BaseMessage(
|
||||
role_name="Task Specifier",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You can make a task more specific.",
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
output_language=output_language,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_prompt: Union[str, TextPrompt],
|
||||
meta_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> TextPrompt:
|
||||
r"""Specify the given task prompt by providing more details.
|
||||
|
||||
Args:
|
||||
task_prompt (Union[str, TextPrompt]): The original task
|
||||
prompt.
|
||||
meta_dict (Dict[str, Any], optional): A dictionary containing
|
||||
additional information to include in the prompt.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
TextPrompt: The specified task prompt.
|
||||
"""
|
||||
self.reset()
|
||||
task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)
|
||||
|
||||
if meta_dict is not None:
|
||||
task_specify_prompt = task_specify_prompt.format(**meta_dict)
|
||||
task_msg = BaseMessage.make_user_message(
|
||||
role_name="Task Specifier", content=task_specify_prompt
|
||||
)
|
||||
specifier_response = self.step(task_msg)
|
||||
|
||||
if specifier_response.terminated:
|
||||
raise RuntimeError("Task specification failed.")
|
||||
if len(specifier_response.msgs) == 0:
|
||||
raise RuntimeError("Got no specification message.")
|
||||
|
||||
specified_task_msg = specifier_response.msgs[0]
|
||||
|
||||
return TextPrompt(specified_task_msg.content)
|
||||
|
||||
|
||||
@track_agent(name="TaskPlannerAgent")
|
||||
class TaskPlannerAgent(ChatAgent):
|
||||
r"""An agent that helps divide a task into subtasks based on the input
|
||||
task prompt.
|
||||
|
||||
Attributes:
|
||||
task_planner_prompt (TextPrompt): A prompt for the agent to divide
|
||||
the task into subtasks.
|
||||
|
||||
Args:
|
||||
model (BaseModelBackend, optional): The model backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
output_language (str, optional): The language to be output by the
|
||||
agent. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
output_language: Optional[str] = None,
|
||||
) -> None:
|
||||
self.task_planner_prompt = TextPrompt(
|
||||
"Divide this task into subtasks: {task}. Be concise."
|
||||
)
|
||||
system_message = BaseMessage(
|
||||
role_name="Task Planner",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You are a helpful task planner.",
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
output_language=output_language,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_prompt: Union[str, TextPrompt],
|
||||
) -> TextPrompt:
|
||||
r"""Generate subtasks based on the input task prompt.
|
||||
|
||||
Args:
|
||||
task_prompt (Union[str, TextPrompt]): The prompt for the task to
|
||||
be divided into subtasks.
|
||||
|
||||
Returns:
|
||||
TextPrompt: A prompt for the subtasks generated by the agent.
|
||||
"""
|
||||
# TODO: Maybe include roles information.
|
||||
self.reset()
|
||||
task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)
|
||||
|
||||
task_msg = BaseMessage.make_user_message(
|
||||
role_name="Task Planner", content=task_planner_prompt
|
||||
)
|
||||
|
||||
task_response = self.step(task_msg)
|
||||
|
||||
if task_response.terminated:
|
||||
raise RuntimeError("Task planning failed.")
|
||||
if len(task_response.msgs) == 0:
|
||||
raise RuntimeError("Got no task planning message.")
|
||||
|
||||
sub_tasks_msg = task_response.msgs[0]
|
||||
return TextPrompt(sub_tasks_msg.content)
|
||||
|
||||
|
||||
@track_agent(name="TaskCreationAgent")
|
||||
class TaskCreationAgent(ChatAgent):
|
||||
r"""An agent that helps create new tasks based on the objective
|
||||
and last completed task. Compared to :obj:`TaskPlannerAgent`,
|
||||
it's still a task planner, but it has more context information
|
||||
like last task and incomplete task list. Modified from
|
||||
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
|
||||
|
||||
Attributes:
|
||||
task_creation_prompt (TextPrompt): A prompt for the agent to
|
||||
create new tasks.
|
||||
|
||||
Args:
|
||||
role_name (str): The role name of the Agent to create the task.
|
||||
objective (Union[str, TextPrompt]): The objective of the Agent to
|
||||
perform the task.
|
||||
model (BaseModelBackend, optional): The LLM backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
output_language (str, optional): The language to be output by the
|
||||
agent. (default: :obj:`None`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`None`)
|
||||
max_task_num (int, optional): The maximum number of planned
|
||||
tasks in one round. (default: :obj:3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role_name: str,
|
||||
objective: Union[str, TextPrompt],
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
output_language: Optional[str] = None,
|
||||
message_window_size: Optional[int] = None,
|
||||
max_task_num: Optional[int] = 3,
|
||||
) -> None:
|
||||
task_creation_prompt = TextPrompt(
|
||||
"""Create new a task with the following objective: {objective}.
|
||||
Never forget you are a Task Creator of {role_name}.
|
||||
You must instruct me based on my expertise and your needs to solve the task.
|
||||
You should consider past solved tasks and in-progress tasks: {task_list}.
|
||||
The new created tasks must not overlap with these past tasks.
|
||||
The result must be a numbered list in the format:
|
||||
|
||||
#. First Task
|
||||
#. Second Task
|
||||
#. Third Task
|
||||
|
||||
You can only give me up to {max_task_num} tasks at a time. \
|
||||
Each task should be concise, concrete and doable for a {role_name}.
|
||||
You should make task plan and not ask me questions.
|
||||
If you think no new tasks are needed right now, write "No tasks to add."
|
||||
Now start to give me new tasks one by one. No more than three tasks.
|
||||
Be concrete.
|
||||
"""
|
||||
)
|
||||
|
||||
self.task_creation_prompt = task_creation_prompt.format(
|
||||
objective=objective, role_name=role_name, max_task_num=max_task_num
|
||||
)
|
||||
self.objective = objective
|
||||
|
||||
system_message = BaseMessage(
|
||||
role_name="Task Creator",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You are a helpful task creator.",
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
output_language=output_language,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_list: List[str],
|
||||
) -> List[str]:
|
||||
r"""Generate subtasks based on the previous task results and
|
||||
incomplete task list.
|
||||
|
||||
Args:
|
||||
task_list (List[str]): The completed or in-progress
|
||||
tasks which should not overlap with new created tasks.
|
||||
|
||||
Returns:
|
||||
List[str]: The new task list generated by the Agent.
|
||||
"""
|
||||
|
||||
if len(task_list) > 0:
|
||||
task_creation_prompt = self.task_creation_prompt.format(
|
||||
task_list=task_list
|
||||
)
|
||||
else:
|
||||
task_creation_prompt = self.task_creation_prompt.format(
|
||||
task_list=""
|
||||
)
|
||||
|
||||
task_msg = BaseMessage.make_user_message(
|
||||
role_name="Task Creator", content=task_creation_prompt
|
||||
)
|
||||
task_response = self.step(task_msg)
|
||||
|
||||
if task_response.terminated:
|
||||
raise RuntimeError("Task creation failed.")
|
||||
if len(task_response.msgs) == 0:
|
||||
raise RuntimeError("Got no task creation message.")
|
||||
|
||||
sub_tasks_msg = task_response.msgs[0]
|
||||
return get_task_list(sub_tasks_msg.content)
|
||||
|
||||
|
||||
@track_agent(name="TaskPrioritizationAgent")
|
||||
class TaskPrioritizationAgent(ChatAgent):
|
||||
r"""An agent that helps re-prioritize the task list and
|
||||
returns numbered prioritized list. Modified from
|
||||
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
|
||||
|
||||
Attributes:
|
||||
task_prioritization_prompt (TextPrompt): A prompt for the agent to
|
||||
prioritize tasks.
|
||||
|
||||
Args:
|
||||
objective (Union[str, TextPrompt]): The objective of the Agent to
|
||||
perform the task.
|
||||
model (BaseModelBackend, optional): The LLM backend to use for
|
||||
generating responses. (default: :obj:`OpenAIModel` with
|
||||
`GPT_4O_MINI`)
|
||||
output_language (str, optional): The language to be output by the
|
||||
agent. (default: :obj:`None`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
objective: Union[str, TextPrompt],
|
||||
model: Optional[BaseModelBackend] = None,
|
||||
output_language: Optional[str] = None,
|
||||
message_window_size: Optional[int] = None,
|
||||
) -> None:
|
||||
task_prioritization_prompt = TextPrompt(
|
||||
"""Prioritize the following tasks : {task_list}.
|
||||
Consider the ultimate objective of you: {objective}.
|
||||
Tasks should be sorted from highest to lowest priority, where higher-priority \
|
||||
tasks are those that act as pre-requisites or are more essential for meeting \
|
||||
the objective. Return one task per line in your response.
|
||||
Do not remove or modify any tasks.
|
||||
The result must be a numbered list in the format:
|
||||
|
||||
#. First task
|
||||
#. Second task
|
||||
|
||||
The entries must be consecutively numbered, starting with 1.
|
||||
The number of each entry must be followed by a period.
|
||||
Do not include any headers before your ranked list or follow your list \
|
||||
with any other output."""
|
||||
)
|
||||
|
||||
self.task_prioritization_prompt = task_prioritization_prompt.format(
|
||||
objective=objective
|
||||
)
|
||||
self.objective = objective
|
||||
|
||||
system_message = BaseMessage(
|
||||
role_name="Task Prioritizer",
|
||||
role_type=RoleType.ASSISTANT,
|
||||
meta_dict=None,
|
||||
content="You are a helpful task prioritizer.",
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
system_message,
|
||||
model=model,
|
||||
output_language=output_language,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task_list: List[str],
|
||||
) -> List[str]:
|
||||
r"""Prioritize the task list given the agent objective.
|
||||
|
||||
Args:
|
||||
task_list (List[str]): The unprioritized tasks of agent.
|
||||
|
||||
Returns:
|
||||
List[str]: The new prioritized task list generated by the Agent.
|
||||
"""
|
||||
task_prioritization_prompt = self.task_prioritization_prompt.format(
|
||||
task_list=task_list
|
||||
)
|
||||
|
||||
task_msg = BaseMessage.make_user_message(
|
||||
role_name="Task Prioritizer", content=task_prioritization_prompt
|
||||
)
|
||||
|
||||
task_response = self.step(task_msg)
|
||||
|
||||
if task_response.terminated:
|
||||
raise RuntimeError("Task prioritization failed.")
|
||||
if len(task_response.msgs) == 0:
|
||||
raise RuntimeError("Got no task prioritization message.")
|
||||
|
||||
sub_tasks_msg = task_response.msgs[0]
|
||||
return get_task_list(sub_tasks_msg.content)
|
||||
20
camel/agents/tool_agents/__init__.py
Normal file
20
camel/agents/tool_agents/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base import BaseToolAgent
|
||||
from .hugging_face_tool_agent import HuggingFaceToolAgent
|
||||
|
||||
__all__ = [
|
||||
'BaseToolAgent',
|
||||
'HuggingFaceToolAgent',
|
||||
]
|
||||
39
camel/agents/tool_agents/base.py
Normal file
39
camel/agents/tool_agents/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from camel.agents import BaseAgent
|
||||
|
||||
|
||||
class BaseToolAgent(BaseAgent):
|
||||
r"""Creates a :obj:`BaseToolAgent` object with the specified name and
|
||||
description.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool agent.
|
||||
description (str): The description of the tool agent.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
def reset(self) -> None:
|
||||
r"""Resets the agent to its initial state."""
|
||||
pass
|
||||
|
||||
def step(self) -> None:
|
||||
r"""Performs a single step of the agent."""
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}: {self.description}"
|
||||
206
camel/agents/tool_agents/hugging_face_tool_agent.py
Normal file
206
camel/agents/tool_agents/hugging_face_tool_agent.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Optional
|
||||
|
||||
from camel.agents.tool_agents.base import BaseToolAgent
|
||||
|
||||
|
||||
# flake8: noqa :E501
|
||||
class HuggingFaceToolAgent(BaseToolAgent):
|
||||
r"""Tool agent for calling HuggingFace models. This agent is a wrapper
|
||||
around agents from the `transformers` library. For more information
|
||||
about the available models, please see the `transformers` documentation
|
||||
at https://huggingface.co/docs/transformers/transformers_agents.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
*args (Any): Additional positional arguments to pass to the underlying
|
||||
Agent class.
|
||||
remote (bool, optional): Flag indicating whether to run the agent
|
||||
remotely. (default: :obj:`True`)
|
||||
**kwargs (Any): Additional keyword arguments to pass to the underlying
|
||||
Agent class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*args: Any,
|
||||
remote: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
# TODO: Support other tool agents
|
||||
import transformers
|
||||
from packaging import version
|
||||
|
||||
if version.parse(transformers.__version__) < version.parse(
|
||||
"4.31.0"
|
||||
):
|
||||
raise ValueError(
|
||||
"The version of \"transformers\" package should >= 4.31.0"
|
||||
)
|
||||
|
||||
from transformers.tools import OpenAiAgent
|
||||
from transformers.tools.agent_types import AgentImage
|
||||
except (ImportError, ValueError):
|
||||
raise ValueError(
|
||||
"Could not import transformers tool agents. "
|
||||
"Please setup the environment with "
|
||||
"pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
|
||||
)
|
||||
self.agent_image_type = AgentImage
|
||||
self.agent = OpenAiAgent(*args, **kwargs)
|
||||
description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
|
||||
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
|
||||
- Text question answering: given a long text and a question, answer the question in the text
|
||||
- Unconditional image captioning: Caption the image!
|
||||
- Image question answering: given an image, answer a question on this image
|
||||
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
|
||||
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
|
||||
- Text to speech: convert text to speech
|
||||
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
|
||||
- Text summarization: summarize a long text in one or a few sentences
|
||||
- Translation: translate the text into a given language
|
||||
- Text downloading: to download a text from a web URL
|
||||
- Text to image: generate an image according to a prompt, leveraging stable diffusion
|
||||
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
|
||||
- Text to video: generate a small video according to a prompt
|
||||
|
||||
Here are some python code examples of what you can do with this agent:
|
||||
|
||||
Single execution (step) mode, the single execution method is when using the step() method of the agent:
|
||||
```
|
||||
# Text to image
|
||||
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
|
||||
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
|
||||
|
||||
# Text to image -> Image transformation
|
||||
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
|
||||
sea_add_island_image.save("./sea_add_island_image.png")
|
||||
|
||||
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
|
||||
# you can do so by specifying variables that you would like the agent to use. For example,
|
||||
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
|
||||
picture = {name}.step("Generate a picture of rivers and lakes.")
|
||||
picture.save("./picture.png")
|
||||
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
|
||||
updated_picture.save("./updated_picture.png")
|
||||
|
||||
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
|
||||
capybara_sea_image.save("./capybara_sea_image.png")
|
||||
|
||||
# Document question answering
|
||||
answer = {name}.step(
|
||||
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
|
||||
document=document,
|
||||
)
|
||||
print(answer)
|
||||
|
||||
|
||||
# Text to image
|
||||
boat_image = {name}.step("Generate an image of a boat in the water")
|
||||
boat_image.save("./boat_image.png")
|
||||
|
||||
# Unconditional image captioning
|
||||
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
|
||||
print(boat_image_caption)
|
||||
|
||||
# Text to image -> Unconditional image captioning -> Text to speech
|
||||
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
|
||||
|
||||
# Text downloading
|
||||
document = {name}.step("Download the text from http://hf.co")
|
||||
print(document)
|
||||
|
||||
# Text summarization
|
||||
summary = {name}.step("Summarize the following text: `document`", document=document)
|
||||
print(summary)
|
||||
|
||||
# Text downloading -> Text summarization -> Text to speech
|
||||
audio = {name}.step("Read out loud the summary of http://hf.co")
|
||||
```
|
||||
|
||||
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
|
||||
```
|
||||
# Clean the chat history
|
||||
{name}.reset()
|
||||
|
||||
# Text to image
|
||||
capybara_image = {name}.chat("Show me an an image of a capybara")
|
||||
capybara_image.save("./capybara_image.png")
|
||||
|
||||
# Image transformation
|
||||
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
|
||||
transformed_capybara_image.save("./transformed_capybara_image.png")
|
||||
|
||||
# Image segmentation
|
||||
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
|
||||
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
|
||||
```
|
||||
"""
|
||||
super(HuggingFaceToolAgent, self).__init__(name, description)
|
||||
self.remote = remote
|
||||
|
||||
def reset(self) -> None:
|
||||
r"""Resets the chat history of the agent."""
|
||||
self.agent.prepare_for_new_chat()
|
||||
|
||||
def step(
|
||||
self,
|
||||
*args: Any,
|
||||
remote: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
r"""Runs the agent in single execution mode.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments to pass to the agent.
|
||||
remote (bool, optional): Flag indicating whether to run the agent
|
||||
remotely. Overrides the default setting. (default: :obj:`None`)
|
||||
**kwargs (Any): Keyword arguments to pass to the agent.
|
||||
|
||||
Returns:
|
||||
str: The response from the agent.
|
||||
"""
|
||||
if remote is None:
|
||||
remote = self.remote
|
||||
agent_output = self.agent.run(*args, remote=remote, **kwargs)
|
||||
if isinstance(agent_output, self.agent_image_type):
|
||||
agent_output = agent_output.to_raw()
|
||||
return agent_output
|
||||
|
||||
def chat(
|
||||
self,
|
||||
*args: Any,
|
||||
remote: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
r"""Runs the agent in a chat conversation mode.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments to pass to the agent.
|
||||
remote (bool, optional): Flag indicating whether to run the agent
|
||||
remotely. Overrides the default setting. (default: :obj:`None`)
|
||||
**kwargs (Any): Keyword arguments to pass to the agent.
|
||||
|
||||
Returns:
|
||||
str: The response from the agent.
|
||||
"""
|
||||
if remote is None:
|
||||
remote = self.remote
|
||||
agent_output = self.agent.chat(*args, remote=remote, **kwargs)
|
||||
if isinstance(agent_output, self.agent_image_type):
|
||||
agent_output = agent_output.to_raw()
|
||||
return agent_output
|
||||
30
camel/benchmarks/__init__.py
Normal file
30
camel/benchmarks/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .apibank import APIBankBenchmark
|
||||
from .apibench import APIBenchBenchmark
|
||||
from .base import BaseBenchmark
|
||||
from .gaia import DefaultGAIARetriever, GAIABenchmark
|
||||
from .nexus import NexusBenchmark
|
||||
from .ragbench import RAGBenchBenchmark
|
||||
|
||||
__all__ = [
|
||||
"BaseBenchmark",
|
||||
"GAIABenchmark",
|
||||
"DefaultGAIARetriever",
|
||||
"NexusBenchmark",
|
||||
"APIBenchBenchmark",
|
||||
"APIBankBenchmark",
|
||||
"RAGBenchBenchmark",
|
||||
]
|
||||
571
camel/benchmarks/apibank.py
Normal file
571
camel/benchmarks/apibank.py
Normal file
@@ -0,0 +1,571 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from rouge import Rouge
|
||||
from tqdm import tqdm
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.benchmarks.base import BaseBenchmark
|
||||
from camel.messages import BaseMessage
|
||||
from camel.utils import download_github_subdirectory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add current folder to sys.path to enable relative import
|
||||
current_folder = os.getcwd()
|
||||
if current_folder not in sys.path:
|
||||
sys.path.append(current_folder)
|
||||
|
||||
|
||||
def process_messages(
|
||||
chat_history: List[Dict[str, Any]],
|
||||
prompt: str,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Processes chat history into a structured format for further use.
|
||||
|
||||
Args:
|
||||
chat_history (List[Dict[str, Any]):
|
||||
A list of dictionaries representing the chat history.
|
||||
prompt (str): A prompt to be set as the system message.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries representing
|
||||
the processed messages, where each dictionary has:
|
||||
- 'role': The role of the message ('system', 'user', or 'assistant').
|
||||
- 'content': The content of the message, including formatted
|
||||
API responses when applicable.
|
||||
"""
|
||||
messages = [{'role': 'system', 'content': prompt}]
|
||||
for item in chat_history:
|
||||
role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
|
||||
chat_role = role_map.get(
|
||||
item['role'], 'unknown'
|
||||
) # default role to 'unknown'
|
||||
if item['role'] == 'API':
|
||||
chat_content = '[{}({})] Response: {}'.format(
|
||||
item['api_name'],
|
||||
', '.join(
|
||||
[
|
||||
'{}=\'{}\''.format(k, v)
|
||||
for k, v in item['param_dict'].items()
|
||||
]
|
||||
),
|
||||
str(item['result']['output']),
|
||||
)
|
||||
else:
|
||||
chat_content = item['text']
|
||||
messages.append({'role': chat_role, 'content': chat_content})
|
||||
return messages
|
||||
|
||||
|
||||
class APIBankBenchmark(BaseBenchmark):
|
||||
r"""API-Bank Benchmark adapted from `API-Bank:
|
||||
A Comprehensive Benchmark for Tool-Augmented LLMs`
|
||||
<https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.
|
||||
|
||||
Args:
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use.
|
||||
(default: :obj:`1`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_to: str,
|
||||
processes: int = 1,
|
||||
):
|
||||
r"""Initialize the APIBank benchmark.
|
||||
|
||||
Args:
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use for
|
||||
parallel processing. (default: :obj:`1`)
|
||||
"""
|
||||
# Predefine data_dir for better import management
|
||||
super().__init__("apibank", "api_bank", save_to, processes)
|
||||
self._data: Dict[str, List[APIBankSample]] = dict() # type: ignore[assignment]
|
||||
|
||||
def download(self):
|
||||
r"""Download APIBank dataset and code from Github."""
|
||||
|
||||
repo = "AlibabaResearch/DAMO-ConvAI"
|
||||
subdir = "api-bank"
|
||||
data_dir = self.data_dir
|
||||
|
||||
download_github_subdirectory(repo, subdir, data_dir)
|
||||
|
||||
sys.path.insert(0, self.data_dir)
|
||||
logger.info("Download completed.")
|
||||
|
||||
def load(self, level: str, force_download: bool = False): # type: ignore[override]
|
||||
r"""Load the APIBank Benchmark dataset.
|
||||
|
||||
Args:
|
||||
level (str): Level to run benchmark on.
|
||||
force_download (bool, optional): Whether to
|
||||
force download the data.
|
||||
"""
|
||||
if force_download:
|
||||
logger.info("Force downloading data.")
|
||||
self.download()
|
||||
|
||||
if level == "level-1":
|
||||
file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
|
||||
elif level == 'level-2':
|
||||
file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
|
||||
jsonl_files = [
|
||||
f for f in os.listdir(file_path) if f.endswith('.jsonl')
|
||||
]
|
||||
for file in tqdm(jsonl_files, desc="Processing files"):
|
||||
history = []
|
||||
with open(file_path / file, 'r') as f:
|
||||
for line in f:
|
||||
history.append(json.loads(line))
|
||||
samples = APIBankSample.from_chat_history(history)
|
||||
self._data[file.rsplit('.', 1)[0]] = samples
|
||||
|
||||
# Change import to relative import in the downloaded python files
|
||||
def process_files(folder_path, replacements):
|
||||
r"""Replace absolute imports in downloaded files with
|
||||
relative import."""
|
||||
for file in os.listdir(folder_path):
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(folder_path, file)
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
|
||||
original_content = content
|
||||
|
||||
for pattern, replacement in replacements:
|
||||
content = re.sub(pattern, replacement, content)
|
||||
|
||||
if content != original_content:
|
||||
with open(
|
||||
file_path, "w", encoding="utf-8"
|
||||
) as file:
|
||||
file.write(content)
|
||||
logger.info(f"Updated file: {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing file {file_path}: {e}")
|
||||
|
||||
api_bank_folder = "api_bank"
|
||||
apis_folder = os.path.join(api_bank_folder, "apis")
|
||||
|
||||
apis_replacements = [
|
||||
(r"from apis.api", "from .api"),
|
||||
(r"from apis import", "from .api import"),
|
||||
]
|
||||
|
||||
api_bank_replacements = [
|
||||
(r"from apis", "from .apis"),
|
||||
(r"from api_call_extraction", "from .api_call_extraction"),
|
||||
(r"f'{basename}", r"f'api_bank.{basename}"),
|
||||
]
|
||||
|
||||
process_files(apis_folder, apis_replacements)
|
||||
process_files(api_bank_folder, api_bank_replacements)
|
||||
|
||||
def run( # type: ignore[override, return]
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
level: Literal["level-1", "level-2"],
|
||||
api_test_enabled=True,
|
||||
randomize: bool = False,
|
||||
subset: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent to run the
|
||||
benchmark.
|
||||
level (Literal['level-1', 'level-2']):
|
||||
The level to run the benchmark on.
|
||||
randomize (bool, optional): Whether to
|
||||
randomize the data.
|
||||
api_test_enabled (bool): Whether to test
|
||||
API calling (`True`) or response (`False`)
|
||||
(default: :obj:`False`)
|
||||
subset (Optional[int], optional):
|
||||
The subset of data to run.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The results of the benchmark.
|
||||
"""
|
||||
logger.info(f"Running APIBench benchmark on {level}.")
|
||||
self.load(level)
|
||||
datas = self._data
|
||||
|
||||
# Shuffle and subset data if necessary
|
||||
if randomize:
|
||||
randomized_items = list(datas.items())
|
||||
random.shuffle(randomized_items)
|
||||
datas = dict(randomized_items)
|
||||
if subset:
|
||||
datas = dict(list(datas.items())[:subset])
|
||||
|
||||
logger.info(f"Number of tasks: {len(datas)}")
|
||||
|
||||
# Initialize results storage
|
||||
self._results = []
|
||||
|
||||
# The following code are adapted from the evaluator
|
||||
# from the original repo:
|
||||
tool_search_enabled = level == "level-2"
|
||||
dialog_test_enabled = not api_test_enabled
|
||||
total_api_calls, correct_api_calls, rougel_scores = 0, 0, []
|
||||
|
||||
with open(self.save_to, "w") as f:
|
||||
for test in tqdm(datas, desc="Running"):
|
||||
samples = self._data[test]
|
||||
evaluator = Evaluator(samples) # type: ignore[arg-type]
|
||||
|
||||
for sample_id in evaluator.get_all_sample_ids():
|
||||
# Process sample and generate response
|
||||
sample = evaluator.dataset[sample_id]
|
||||
|
||||
if (
|
||||
sample.ground_truth['role'] == 'API'
|
||||
and api_test_enabled
|
||||
):
|
||||
if tool_search_enabled:
|
||||
_, chat_history = evaluator.get_model_input(
|
||||
sample_id
|
||||
)
|
||||
api_descriptions = evaluator.get_api_description(
|
||||
'ToolSearcher'
|
||||
)
|
||||
else:
|
||||
api_descriptions, chat_history = (
|
||||
evaluator.get_model_input(sample_id)
|
||||
)
|
||||
messages = process_messages(
|
||||
chat_history, API_CALL_PROMPT + api_descriptions
|
||||
)
|
||||
model_output = agent_call(messages, agent)
|
||||
api_call = get_api_call(model_output)
|
||||
|
||||
# Evaluate API call
|
||||
if api_call:
|
||||
try:
|
||||
correct, model_output_result = (
|
||||
evaluator.evaluate(sample_id, api_call)
|
||||
)
|
||||
except AssertionError as e:
|
||||
if 'The API name is not correct.' not in str(
|
||||
e
|
||||
):
|
||||
raise e
|
||||
logging.info('AssertionError: {}'.format(e))
|
||||
correct = False
|
||||
else:
|
||||
model_output_result = 'No API call found'
|
||||
correct = False
|
||||
if correct:
|
||||
correct_api_calls += 1
|
||||
logging.info(
|
||||
'Correct API call: {} Ground truth: {}'.format(
|
||||
api_call, sample.ground_truth
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
'Incorrect model output: {} Result: {} \
|
||||
Ground truth: {} File: {} Sample ID: {} \
|
||||
Messages: {}'.format(
|
||||
model_output.replace('\n', ' '),
|
||||
model_output_result,
|
||||
sample.ground_truth,
|
||||
test,
|
||||
sample_id,
|
||||
messages[1:],
|
||||
)
|
||||
)
|
||||
total_api_calls += 1
|
||||
self._results.append(
|
||||
{
|
||||
'Role': 'API',
|
||||
'Model_output': model_output,
|
||||
'Model_output_result': model_output_result,
|
||||
'Ground_truth': sample.ground_truth,
|
||||
'Test': test,
|
||||
'Correct': correct,
|
||||
}
|
||||
)
|
||||
json_str = json.dumps(
|
||||
self._results[-1], indent=2, ensure_ascii=False
|
||||
)
|
||||
f.write(json_str + "\n")
|
||||
|
||||
elif (
|
||||
sample.ground_truth['role'] == 'AI'
|
||||
and dialog_test_enabled
|
||||
):
|
||||
# Process sample and generate response
|
||||
api_descriptions, chat_history = (
|
||||
evaluator.get_model_input(sample_id)
|
||||
)
|
||||
|
||||
messages = process_messages(
|
||||
chat_history, RESPONSE_PROMPT + api_descriptions
|
||||
)
|
||||
model_output = agent_call(messages, agent)
|
||||
|
||||
# Evaluate model response
|
||||
if model_output:
|
||||
score = evaluator.evaluate(sample_id, model_output)
|
||||
else:
|
||||
score = 0
|
||||
rougel_scores.append(score)
|
||||
if score < 0.2:
|
||||
logging.info(
|
||||
'Low score: {} Score: {} Ground truth: {} \
|
||||
Test: {} Sample ID: {} \
|
||||
Messages: {}'.format(
|
||||
model_output.replace('\n', ' '),
|
||||
score,
|
||||
sample.ground_truth,
|
||||
test,
|
||||
sample_id,
|
||||
messages[1:],
|
||||
)
|
||||
)
|
||||
|
||||
self._results.append(
|
||||
{
|
||||
'Role': 'AI',
|
||||
'Model_output': model_output,
|
||||
'Score': score,
|
||||
'Ground_truth': sample.ground_truth,
|
||||
'Test': test,
|
||||
}
|
||||
)
|
||||
json_str = json.dumps(
|
||||
self._results[-1], indent=2, ensure_ascii=False
|
||||
)
|
||||
f.write(json_str + "\n")
|
||||
|
||||
f.flush()
|
||||
|
||||
if api_test_enabled:
|
||||
return {
|
||||
'total': total_api_calls,
|
||||
'correct': correct_api_calls,
|
||||
"accuracy": correct_api_calls / total_api_calls
|
||||
if total_api_calls
|
||||
else 0,
|
||||
}
|
||||
elif dialog_test_enabled:
|
||||
return {'Dialog_score': np.mean(rougel_scores)}
|
||||
|
||||
|
||||
# The following code are migrated from the original repo:
|
||||
# https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
|
||||
def agent_call(messages: List[Dict], agent: ChatAgent):
|
||||
r"""Add messages to agent memory and get response."""
|
||||
for i, msg in enumerate(messages):
|
||||
if msg['role'] == 'user':
|
||||
message = BaseMessage.make_user_message(
|
||||
role_name="CAMEL User", content=msg['content']
|
||||
)
|
||||
elif msg['role'] == 'assistant':
|
||||
message = BaseMessage.make_assistant_message(
|
||||
role_name="CAMEL Assistant", content=msg['content']
|
||||
)
|
||||
elif msg['role'] == 'system':
|
||||
message = BaseMessage.make_assistant_message(
|
||||
role_name="System", content=msg['content']
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized role: {msg['role']}")
|
||||
|
||||
if i == len(messages) - 1:
|
||||
break
|
||||
agent.record_message(message)
|
||||
|
||||
response = agent.step(message)
|
||||
model_output = response.msgs[0].content
|
||||
agent.reset()
|
||||
return model_output
|
||||
|
||||
|
||||
def calculate_rouge_l_score(reference, hypothesis):
|
||||
r"""Calculate rouge l score between hypothesis and reference."""
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(hypothesis, reference)
|
||||
rouge_l_score = scores[0]['rouge-l']['f']
|
||||
return rouge_l_score
|
||||
|
||||
|
||||
def get_api_call(model_output):
|
||||
r"""Parse api call from model output."""
|
||||
api_call_pattern = r"\[(\w+)\((.*)\)\]"
|
||||
api_call_pattern = re.compile(api_call_pattern)
|
||||
match = api_call_pattern.search(model_output)
|
||||
if match:
|
||||
return match.group(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class APIBankSample:
|
||||
r"""APIBank sample used to load the datasets."""
|
||||
|
||||
def __init__(self, chat_history, apis, ground_truth):
|
||||
self.chat_history = chat_history
|
||||
self.apis = apis
|
||||
self.ground_truth = ground_truth
|
||||
|
||||
def __repr__(self):
|
||||
return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
|
||||
self.chat_history, self.apis, self.ground_truth
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_chat_history(cls, chat_history):
|
||||
apis = set()
|
||||
api_positions = []
|
||||
for i, item in enumerate(chat_history):
|
||||
if item['role'] == 'API':
|
||||
apis.add(item['api_name'])
|
||||
api_positions.append(i)
|
||||
|
||||
samples = []
|
||||
for i in api_positions:
|
||||
sample = cls(chat_history[:i], apis, chat_history[i])
|
||||
samples.append(sample)
|
||||
sample = cls(chat_history[: i + 1], apis, chat_history[i + 1])
|
||||
samples.append(sample)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
class Evaluator:
|
||||
r"""Evaluator for APIBank benchmark."""
|
||||
|
||||
def __init__(self, samples: List[APIBankSample]):
|
||||
# Place holder for import as the import
|
||||
# only works after the files have been downloaded
|
||||
try:
|
||||
from api_bank.tool_manager import ( # type: ignore[import-not-found]
|
||||
ToolManager,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"{e}, Module will be imported after download.")
|
||||
self.dataset = samples
|
||||
self.sample_ids = list(range(len(self.dataset)))
|
||||
os.chdir("api_bank")
|
||||
self.tool_manager = ToolManager("apis")
|
||||
os.chdir("..")
|
||||
|
||||
def get_all_sample_ids(self):
|
||||
return self.sample_ids
|
||||
|
||||
def get_api_description(self, api_name):
|
||||
return self.tool_manager.get_api_description(api_name)
|
||||
|
||||
def get_model_input(self, sample_id: int):
|
||||
sample = self.dataset[sample_id]
|
||||
apis = sample.apis
|
||||
chat_history = sample.chat_history
|
||||
api_descriptions = []
|
||||
for api_name in apis:
|
||||
api_descriptions.append(
|
||||
self.tool_manager.get_api_description(api_name)
|
||||
)
|
||||
api_description = '\n'.join(api_descriptions)
|
||||
return api_description, chat_history
|
||||
|
||||
def evaluate(self, sample_id, model_output):
|
||||
try:
|
||||
from api_bank.api_call_extraction import ( # type: ignore[import-not-found]
|
||||
parse_api_call,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"{e}, Module will be imported after download.")
|
||||
sample = self.dataset[sample_id]
|
||||
ground_truth = sample.ground_truth
|
||||
if ground_truth['role'] == 'API':
|
||||
api_name, param_dict = parse_api_call(model_output)
|
||||
if api_name != ground_truth['api_name']:
|
||||
return False, 'API Name Mismatch: {} vs {}'.format(
|
||||
api_name, ground_truth['api_name']
|
||||
)
|
||||
try:
|
||||
result = self.tool_manager.api_call(api_name, **param_dict)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
api = self.tool_manager.init_tool(api_name)
|
||||
try:
|
||||
correct = api.check_api_call_correctness(
|
||||
result, ground_truth['result']
|
||||
)
|
||||
except KeyError:
|
||||
correct = False
|
||||
result = 'KeyError' + str(result)
|
||||
return correct, result
|
||||
elif ground_truth['role'] == 'AI':
|
||||
score = calculate_rouge_l_score(ground_truth['text'], model_output)
|
||||
return round(score, 4)
|
||||
|
||||
|
||||
API_CALL_PROMPT = '''
|
||||
Based on the given API description and the existing \
|
||||
conversation history 1..t, please generate the API request \
|
||||
that the AI should call in step t+1 and output it in the \
|
||||
format of [ApiName(key1='value1', key2='value2', ...)], \
|
||||
replace the ApiName with the actual API name, and \
|
||||
replace the key and value with the actual parameters. \
|
||||
Your output should start with a square bracket "[" \
|
||||
and end with a square bracket "]". Do not output any \
|
||||
other explanation or prompt or the result of the API call in your output.
|
||||
This year is 2023.
|
||||
Input:
|
||||
User: [User's utterence]
|
||||
AI: [AI's utterence]
|
||||
|
||||
Expected output:
|
||||
[ApiName(key1='value1', key2='value2', ...)]
|
||||
|
||||
API descriptions:
|
||||
'''
|
||||
|
||||
RESPONSE_PROMPT = '''
|
||||
Based on the given API description and the existing \
|
||||
conversation history 1..t, please generate the next \
|
||||
dialog that the AI should response after the API call t.
|
||||
This year is 2023.
|
||||
Input:
|
||||
User: [User's utterence]
|
||||
AI: [AI's utterence]
|
||||
[ApiName(key1='value1', key2='value2', …)]
|
||||
|
||||
Expected output:
|
||||
AI: [AI's utterence]
|
||||
|
||||
API descriptions:
|
||||
'''
|
||||
499
camel/benchmarks/apibench.py
Normal file
499
camel/benchmarks/apibench.py
Normal file
@@ -0,0 +1,499 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
import tree_sitter_python as tspython
|
||||
from tqdm import tqdm
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.benchmarks.base import BaseBenchmark
|
||||
from camel.utils import download_github_subdirectory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of dataset names to file names
|
||||
# 'Oracle' retriever used here which means all the full
|
||||
# API documentation will be included in the prompt
|
||||
dataset_mapping = {
|
||||
"huggingface": {
|
||||
"api": "huggingface_api.jsonl",
|
||||
"eval": "huggingface_eval.json",
|
||||
"train": "huggingface_train.json",
|
||||
"questions": "questions_huggingface_oracle.jsonl",
|
||||
},
|
||||
"tensorflowhub": {
|
||||
"api": "tensorflowhub_api.jsonl",
|
||||
"eval": "tensorflow_eval.json",
|
||||
"train": "tensorflow_train.json",
|
||||
"questions": "questions_tensorflowhub_oracle.jsonl",
|
||||
},
|
||||
"torchhub": {
|
||||
"api": "torchhub_api.jsonl",
|
||||
"eval": "torchhub_eval.json",
|
||||
"train": "torchhub_train.json",
|
||||
"questions": "questions_torchhub_oracle.jsonl",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# This function is migrated from the original repo:
|
||||
# https://github.com/ShishirPatil/gorilla
|
||||
def encode_question(question: str, dataset_name: str) -> str:
|
||||
r"""Encode multiple prompt instructions into a single string."""
|
||||
|
||||
if dataset_name == "torchhub":
|
||||
domains = "1. $DOMAIN is inferred from the task description and \
|
||||
should include one of {Classification, Semantic Segmentation, \
|
||||
Object Detection, Audio Separation, Video Classification, \
|
||||
Text-to-Speech}."
|
||||
elif dataset_name == "huggingface":
|
||||
domains = "1. $DOMAIN should include one of {Multimodal Feature \
|
||||
Extraction, Multimodal Text-to-Image, Multimodal \
|
||||
Image-to-Text, Multimodal Text-to-Video, \
|
||||
Multimodal Visual Question Answering, Multimodal Document \
|
||||
Question Answer, Multimodal Graph Machine Learning, \
|
||||
Computer Vision Depth Estimation, Computer Vision Image \
|
||||
Classification, Computer Vision Object Detection, \
|
||||
Computer Vision Image Segmentation, Computer Vision \
|
||||
Image-to-Image, Computer Vision Unconditional \
|
||||
Image Generation, Computer Vision Video Classification, \
|
||||
Computer Vision Zero-Shor Image Classification, \
|
||||
Natural Language Processing Text Classification, \
|
||||
Natural Language Processing Token Classification, \
|
||||
Natural Language Processing Table Question Answering, \
|
||||
Natural Language Processing Question Answering, \
|
||||
Natural Language Processing, Zero-Shot Classification \
|
||||
Natural Language Processing Translation, Natural Language \
|
||||
Processing Summarization, Natural Language Processing \
|
||||
Conversational, Natural Language Processing Text \
|
||||
Generation, Natural Language Processing Fill-Mask, \
|
||||
Natural Language Processing Text2Text Generation, \
|
||||
Natural Language Processing Sentence Similarity, \
|
||||
Audio Text-to-Speech, Audio Automatic Speech Recognition, \
|
||||
Audio Audio-to-Audio, Audio Audio Classification, \
|
||||
Audio Voice Activity Detection, Tabular Tabular \
|
||||
Classification, Tabular Tabular Regression, \
|
||||
Reinforcement Learning Reinforcement Learning, \
|
||||
Reinforcement Learning Robotics }"
|
||||
elif dataset_name == "tensorflowhub":
|
||||
domains = "1. $DOMAIN is inferred from the task description \
|
||||
and should include one of {text-sequence-alignment, \
|
||||
text-embedding, text-language-model, text-preprocessing, \
|
||||
text-classification, text-generation, text-question-answering, \
|
||||
text-retrieval-question-answering, text-segmentation, \
|
||||
text-to-mel, image-classification, image-feature-vector, \
|
||||
image-object-detection, image-segmentation, \
|
||||
image-generator, image-pose-detection, image-rnn-agent, \
|
||||
image-augmentation, image-classifier, image-style-transfer, \
|
||||
image-aesthetic-quality, image-depth-estimation, \
|
||||
image-super-resolution, image-deblurring, image-extrapolation, \
|
||||
image-text-recognition, image-dehazing, image-deraining, \
|
||||
image-enhancemenmt, image-classification-logits, \
|
||||
image-frame-interpolation, image-text-detection, image-denoising, \
|
||||
image-others, video-classification, video-feature-extraction, \
|
||||
video-generation, video-audio-text, video-text, \
|
||||
audio-embedding, audio-event-classification, audio-command-detection, \
|
||||
audio-paralinguists-classification, audio-speech-to-text, \
|
||||
audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
|
||||
else:
|
||||
logger.info("Error: API name is not supported.")
|
||||
|
||||
prompt = (
|
||||
question
|
||||
+ "\nWrite a python program in 1 to 2 lines to call API in "
|
||||
+ dataset_name
|
||||
+ ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
|
||||
<<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
|
||||
<<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
|
||||
Here are the requirements:\n"
|
||||
+ domains
|
||||
+ "\n2. The $API_CALL should have only 1 line of code \
|
||||
that calls api.\n 3. The $API_PROVIDER should be the \
|
||||
programming framework used.\n4. $EXPLANATION should be \
|
||||
a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
|
||||
Do not repeat the format in your answer."
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
class APIBenchBenchmark(BaseBenchmark):
|
||||
r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
|
||||
Connected with Massive APIs`
|
||||
<https://huggingface.co/datasets/gorilla-llm/APIBench>.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use.
|
||||
(default: :obj:`1`)
|
||||
"""
|
||||
|
||||
# TODO: Integrate retriever (pending)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
save_to: str,
|
||||
processes: int = 1,
|
||||
):
|
||||
r"""Initialize the APIBench benchmark.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use for
|
||||
parallel processing. (default: :obj:`1`)
|
||||
"""
|
||||
super().__init__("apibench", data_dir, save_to, processes)
|
||||
|
||||
def download(self):
|
||||
r"""Download the APIBench dataset."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="gorilla-llm/APIBench",
|
||||
repo_type="dataset",
|
||||
local_dir=self.data_dir,
|
||||
local_dir_use_symlinks=True,
|
||||
)
|
||||
|
||||
repo = "ShishirPatil/gorilla"
|
||||
subdir = "/gorilla/eval/eval-data/questions"
|
||||
data_dir = self.data_dir
|
||||
|
||||
download_github_subdirectory(repo, subdir, data_dir)
|
||||
|
||||
def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
|
||||
r"""Load the APIBench Benchmark dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the specific dataset to be loaded.
|
||||
force_download (bool, optional): Whether to force
|
||||
download the data. (default: :obj:`False`)
|
||||
"""
|
||||
|
||||
if force_download:
|
||||
logger.info("Force downloading data.")
|
||||
self.download()
|
||||
|
||||
def load_json_lines(file_path: Path):
|
||||
r"""Helper function to load JSON lines from a file."""
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
return [json.loads(line) for line in f]
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error decoding JSON in file {file_path}: {e}"
|
||||
)
|
||||
|
||||
dataset_path = self.data_dir / dataset_name
|
||||
if not dataset_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Dataset directory does not exist: {dataset_path}"
|
||||
)
|
||||
|
||||
for label in ['api', 'eval', 'questions']:
|
||||
file_name = dataset_mapping[dataset_name][label]
|
||||
file_path = (
|
||||
dataset_path / file_name
|
||||
if label == 'questions'
|
||||
else self.data_dir / file_name
|
||||
)
|
||||
|
||||
# Load data based on label type
|
||||
if label in ['api', 'questions', 'eval']:
|
||||
data = load_json_lines(file_path)
|
||||
|
||||
if label == 'eval':
|
||||
# Extract 'api_data' specifically for eval label
|
||||
data = [item['api_data'] for item in data]
|
||||
|
||||
self._data[label] = data
|
||||
else:
|
||||
raise ValueError(f"Unknown label: {label}")
|
||||
|
||||
ast_database = []
|
||||
for data in self._data['api']:
|
||||
ast_tree = ast_parse(data['api_call'])
|
||||
ast_database.append(ast_tree)
|
||||
self._data['ast'] = ast_database
|
||||
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
|
||||
randomize: bool = False,
|
||||
subset: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent to run the
|
||||
benchmark.
|
||||
dataset_name (Literal["huggingface",
|
||||
"tensorflowhub", "torchhub"]):
|
||||
The dataset to run the benchmark.
|
||||
randomize (bool, optional): Whether to randomize the data.
|
||||
(default: :obj:`False`)
|
||||
subset (Optional[int], optional): The subset of data to run.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
if dataset_name not in dataset_mapping:
|
||||
raise ValueError(f"Invalid value for dataset: {dataset_name}.")
|
||||
|
||||
logger.info(f"Running APIBench benchmark on {dataset_name}.")
|
||||
self.load(dataset_name)
|
||||
datas = self._data['questions']
|
||||
|
||||
# Shuffle and subset data if necessary
|
||||
if randomize:
|
||||
random.shuffle(datas)
|
||||
if subset:
|
||||
datas = datas[:subset]
|
||||
|
||||
logger.info(f"Number of tasks: {len(datas)}")
|
||||
|
||||
# Initialize results storage
|
||||
self._results = []
|
||||
|
||||
with open(self.save_to, "w") as f:
|
||||
for question in tqdm(datas, desc="Running"):
|
||||
prompt = encode_question(question["text"], dataset_name)
|
||||
try:
|
||||
# Generate response
|
||||
responses = agent.step(prompt)
|
||||
response = responses.msgs[0].content
|
||||
api_database = self._data['api']
|
||||
qa_pairs = self._data['eval']
|
||||
ast_database = self._data['ast']
|
||||
question_id = question['question_id']
|
||||
|
||||
# Evaluate response
|
||||
error, correct, hallucination = evaluate_response(
|
||||
response,
|
||||
question_id,
|
||||
dataset_name,
|
||||
api_database,
|
||||
qa_pairs,
|
||||
ast_database,
|
||||
)
|
||||
self._results.append(
|
||||
{
|
||||
"question": question,
|
||||
"agent_response": response,
|
||||
"correct": correct,
|
||||
"hallucination": hallucination,
|
||||
"error": str(error) if error else None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error in processing task: {question}: {e}"
|
||||
)
|
||||
self._results.append(
|
||||
{
|
||||
"question": question,
|
||||
"agent_response": None,
|
||||
"correct": False,
|
||||
"hallucination": False,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
agent.reset()
|
||||
|
||||
json_str = json.dumps(
|
||||
self._results[-1], indent=2, ensure_ascii=False
|
||||
)
|
||||
f.write(json_str + "\n")
|
||||
f.flush()
|
||||
|
||||
total = len(self._results)
|
||||
correct = sum(r["correct"] for r in self.results)
|
||||
hallucination = sum(r["hallucination"] for r in self.results)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"hallucination": hallucination,
|
||||
"accuracy": correct / total if total else "N/A",
|
||||
"hallucination rate": hallucination / total if total else "N/A",
|
||||
}
|
||||
|
||||
|
||||
# This code is modified from the
|
||||
# evaluators in the original repo
|
||||
# https://github.com/ShishirPatil/gorilla
|
||||
# Get all the subtrees given a root_node
|
||||
def get_all_sub_trees(root_node):
|
||||
node_stack = []
|
||||
sub_tree_sexp_list = []
|
||||
depth = 1
|
||||
# text = root_node.text
|
||||
node_stack.append([root_node, depth])
|
||||
while len(node_stack) != 0:
|
||||
cur_node, cur_depth = node_stack.pop()
|
||||
if cur_node.child_count > 0:
|
||||
sub_tree_sexp_list.append(
|
||||
[
|
||||
str(cur_node),
|
||||
cur_depth,
|
||||
cur_node,
|
||||
cur_node.children[0].text,
|
||||
]
|
||||
)
|
||||
else:
|
||||
sub_tree_sexp_list.append(
|
||||
[str(cur_node), cur_depth, cur_node, None]
|
||||
)
|
||||
for child_node in cur_node.children:
|
||||
if len(child_node.children) != 0:
|
||||
depth = cur_depth + 1
|
||||
node_stack.append([child_node, depth])
|
||||
return sub_tree_sexp_list
|
||||
|
||||
|
||||
# Parse the program into AST trees
|
||||
def ast_parse(candidate):
|
||||
PY_LANGUAGE = Language(tspython.language())
|
||||
parser = Parser(PY_LANGUAGE)
|
||||
|
||||
candidate_tree = parser.parse(bytes(candidate, "utf8")).root_node
|
||||
return candidate_tree
|
||||
|
||||
|
||||
# Get all the arguments in the ast tree
|
||||
def get_args(node, dataset_name):
|
||||
if node.child_count == 0:
|
||||
return []
|
||||
args_list = []
|
||||
if dataset_name == "huggingface":
|
||||
for child in node.children[0].children[0].children[1].children:
|
||||
if "=" in child.text.decode():
|
||||
args_list.append(child.children[2].text)
|
||||
elif (
|
||||
child.text.decode() != "("
|
||||
and child.text.decode() != ")"
|
||||
and child.text.decode() != ","
|
||||
):
|
||||
args_list.append(child.text)
|
||||
elif dataset_name == "tensorflowhub":
|
||||
for child in node.children[0].children[0].children[1].children:
|
||||
if (
|
||||
'model=' in child.text.decode()
|
||||
or 'model =' in child.text.decode()
|
||||
):
|
||||
args_list.append(child.children[2].text)
|
||||
elif (
|
||||
child.text.decode() != "("
|
||||
and child.text.decode() != ")"
|
||||
and child.text.decode() != ","
|
||||
):
|
||||
args_list.append(child.text)
|
||||
elif dataset_name == "torchhub":
|
||||
for child in node.children[0].children[0].children[1].children:
|
||||
if (
|
||||
"repo_or_dir" in child.text.decode()
|
||||
or "model" in child.text.decode()
|
||||
):
|
||||
args_list.append(child.children[2].text)
|
||||
return args_list
|
||||
|
||||
|
||||
# Check if there is an api match
|
||||
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
|
||||
for idx, base_tree in enumerate(base_tree_list):
|
||||
if base_tree.children[0].children[0].child_count == 0:
|
||||
continue
|
||||
api_name = base_tree.children[0].children[0].children[0].text
|
||||
for candidate_tree in candidate_subtree_list:
|
||||
if candidate_tree[3] == api_name:
|
||||
break
|
||||
# Now we have a sub-tree
|
||||
candidate_tree = candidate_tree[2]
|
||||
args_list = get_args(base_tree, dataset_name)
|
||||
if len(args_list) == 0:
|
||||
continue
|
||||
ast_match = True
|
||||
for arg in args_list:
|
||||
if (
|
||||
arg.decode().lstrip("'").rstrip("'")
|
||||
not in candidate_tree.text.decode()
|
||||
):
|
||||
ast_match = False
|
||||
break
|
||||
if ast_match:
|
||||
return idx
|
||||
return -1
|
||||
|
||||
|
||||
def evaluate_response(
|
||||
response, question_id, dataset_name, api_database, qa_pairs, ast_database
|
||||
):
|
||||
try:
|
||||
# Index the "api_call" domain
|
||||
output = response.split("api_call")
|
||||
if len(output) == 1:
|
||||
api_call = output[0]
|
||||
else:
|
||||
# Parse the output
|
||||
output = output[1].split("api_provider")[0]
|
||||
if ":" not in output:
|
||||
start = 0
|
||||
else:
|
||||
start = output.index(":")
|
||||
if ")" not in output:
|
||||
end = -2
|
||||
else:
|
||||
end = output.rindex(")")
|
||||
api_call = output[start + 2 : end + 1]
|
||||
|
||||
try:
|
||||
ast_tree = ast_parse(api_call)
|
||||
except Exception as parse_error:
|
||||
print(f"Error parsing api_call: {api_call}, error: {parse_error}")
|
||||
return parse_error, False, False
|
||||
# Search for a subtree
|
||||
ast_subtree_list = get_all_sub_trees(ast_tree)
|
||||
# Check which ast tree is matching
|
||||
database_index = ast_check(
|
||||
ast_subtree_list, ast_database, dataset_name
|
||||
)
|
||||
# We cannot index this ast in our database
|
||||
if database_index == -1:
|
||||
halluncination = True
|
||||
correct = False
|
||||
# We index our reference api_call
|
||||
ref_api_call = api_database[database_index]
|
||||
# Check for functionality
|
||||
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
|
||||
correct = True
|
||||
halluncination = False
|
||||
else:
|
||||
return None, False, False
|
||||
except Exception as e:
|
||||
print(f'Error parsing response: {response}, error: {e}')
|
||||
return e, False, False
|
||||
|
||||
return None, correct, halluncination
|
||||
152
camel/benchmarks/base.py
Normal file
152
camel/benchmarks/base.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseBenchmark(ABC):
|
||||
r"""Base class for benchmarks.
|
||||
|
||||
Attributes:
|
||||
name (str): Name of the benchmark.
|
||||
data_dir (str): Path to the data directory.
|
||||
save_to (str): Path to save the results.
|
||||
processes (int): Number of processes to use for parallel
|
||||
processing. :(default: :obj:`1`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, data_dir: str, save_to: str, processes: int = 1
|
||||
):
|
||||
r"""Initialize the benchmark.
|
||||
|
||||
Args:
|
||||
name (str): Name of the benchmark.
|
||||
data_dir (str): Path to the data directory.
|
||||
save_to (str): Path to save the results.
|
||||
processes (int): Number of processes to use for parallel
|
||||
processing. :(default: :obj:`1`)
|
||||
|
||||
"""
|
||||
self.name = name
|
||||
self.data_dir = Path(data_dir)
|
||||
self.processes = processes
|
||||
self.save_to = save_to
|
||||
if not self.data_dir.exists():
|
||||
logger.info(
|
||||
f"Data directory {data_dir} does not exist. Creating it."
|
||||
)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not self.data_dir.is_dir():
|
||||
raise NotADirectoryError(
|
||||
f"Data directory {data_dir} is not a directory"
|
||||
)
|
||||
self._data: Dict[str, List[Dict[str, Any]]] = dict()
|
||||
self._results: List[Dict[str, Any]] = []
|
||||
|
||||
@abstractmethod
|
||||
def download(self) -> "BaseBenchmark":
|
||||
r"""Download the benchmark data.
|
||||
|
||||
Returns:
|
||||
BaseBenchmark: The benchmark instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self, force_download: bool = False) -> "BaseBenchmark":
|
||||
r"""Load the benchmark data.
|
||||
|
||||
Args:
|
||||
force_download (bool): Whether to force download the data.
|
||||
|
||||
Returns:
|
||||
BaseBenchmark: The benchmark instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def train(self) -> List[Dict[str, Any]]:
|
||||
r"""Get the training data.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The training data.
|
||||
"""
|
||||
if not self._data:
|
||||
logger.info("Data not loaded. Loading data.")
|
||||
self.load()
|
||||
return self._data["train"]
|
||||
|
||||
@property
|
||||
def valid(self) -> List[Dict[str, Any]]:
|
||||
r"""Get the validation data.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The validation data.
|
||||
"""
|
||||
if not self._data:
|
||||
logger.info("Data not loaded. Loading data.")
|
||||
self.load()
|
||||
return self._data["valid"]
|
||||
|
||||
@property
|
||||
def test(self) -> List[Dict[str, Any]]:
|
||||
r"""Get the test data.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The test data.
|
||||
"""
|
||||
if not self._data:
|
||||
logger.info("Data not loaded. Loading data.")
|
||||
self.load()
|
||||
return self._data["test"]
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
on: Literal["train", "valid", "test"],
|
||||
randomize: bool = False,
|
||||
subset: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> "BaseBenchmark":
|
||||
r"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The chat agent.
|
||||
on (str): The data split to run the benchmark on.
|
||||
randomize (bool): Whether to randomize the data.
|
||||
subset (int): The subset of the data to run the benchmark on.
|
||||
|
||||
Returns:
|
||||
BaseBenchmark: The benchmark instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def results(self) -> List[Dict[str, Any]]:
|
||||
r"""Get the results.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The results.
|
||||
"""
|
||||
return self._results
|
||||
482
camel/benchmarks/gaia.py
Normal file
482
camel/benchmarks/gaia.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.benchmarks.base import BaseBenchmark
|
||||
from camel.messages import BaseMessage
|
||||
from camel.retrievers.auto_retriever import AutoRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrieverProtocol(Protocol):
|
||||
r"""Protocol for the retriever class. Any retriever class implementing
|
||||
this protocol can be used in the benchmark class.
|
||||
"""
|
||||
|
||||
def retrieve(
|
||||
self, query: str, contents: List[str], **kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
r"""Retrieve the relevant content for the query.
|
||||
|
||||
Args:
|
||||
query (str): The query to retrieve the content for.
|
||||
contents (List[str]): The list of contents to search in.
|
||||
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The relevant content for the query.
|
||||
"""
|
||||
...
|
||||
|
||||
def reset(self, **kwargs) -> bool:
|
||||
r"""Reset the retriever.
|
||||
Some benchmarks may require resetting the retriever
|
||||
after each query.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
bool: True if the reset was successful, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultGAIARetriever(AutoRetriever):
|
||||
r"""Default retriever for the GAIA benchmark.
|
||||
This retriever uses AutoRetriever in camel to retrieve the content based on
|
||||
the query.
|
||||
"""
|
||||
|
||||
def retrieve(
|
||||
self, query: str, contents: List[str], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
r"""Retrieve the content based on the query.
|
||||
|
||||
Args:
|
||||
query (str): The query to search for.
|
||||
contents (List[str]): The list of contents to search from.
|
||||
**kwargs (Any): The keyword arguments to pass to the
|
||||
retriever.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The retrieved content.
|
||||
"""
|
||||
return self.run_vector_retriever(query, contents, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
def reset(self, **kwargs: Any) -> bool:
|
||||
r"""Reset the retriever.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): The keyword arguments to pass to the
|
||||
retriever.
|
||||
|
||||
Returns:
|
||||
bool: Whether the reset was successful.
|
||||
"""
|
||||
path = Path(self.vector_storage_local_path or os.getcwd())
|
||||
task_id = str(kwargs.get("task_id", uuid.uuid4()))
|
||||
retriever_dir = path / task_id
|
||||
if not retriever_dir.exists():
|
||||
try:
|
||||
retriever_dir.mkdir(parents=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in creating directory: " + f"{retriever_dir}: {e!s}"
|
||||
)
|
||||
return False
|
||||
self.vector_storage_local_path = str(retriever_dir)
|
||||
return True
|
||||
|
||||
|
||||
class GAIABenchmark(BaseBenchmark):
|
||||
r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
|
||||
Assistants"
|
||||
<https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
retriever (Optional[RetrieverProtocol]): The retriever to use.
|
||||
(default: :obj:`None`)
|
||||
processes (int, optional): The number of processes to use.
|
||||
(default: :obj:`1`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
save_to: str,
|
||||
retriever: Optional[RetrieverProtocol] = None,
|
||||
processes: int = 1,
|
||||
):
|
||||
r"""Initialize the GAIA benchmark.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
retriever (Optional[RetrieverProtocol], optional): The retriever to
|
||||
use. (default: :obj:`None`)
|
||||
processes (int, optional): The number of processes to use for
|
||||
parallel processing. (default: :obj:`1`)
|
||||
"""
|
||||
super().__init__("gaia", data_dir, save_to, processes)
|
||||
self.retriever = retriever or DefaultGAIARetriever()
|
||||
|
||||
def download(self):
|
||||
r"""Download the GAIA dataset."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="gaia-benchmark/GAIA",
|
||||
repo_type="dataset",
|
||||
local_dir=self.data_dir,
|
||||
local_dir_use_symlinks=True,
|
||||
)
|
||||
|
||||
def load(self, force_download=False):
|
||||
r"""Load the GAIA dataset.
|
||||
|
||||
Args:
|
||||
force_download (bool, optional): Whether to
|
||||
force download the data.
|
||||
"""
|
||||
if force_download:
|
||||
logger.info("Force downloading data.")
|
||||
self.download()
|
||||
|
||||
# Define validation and test directories
|
||||
valid_dir = self.data_dir / "2023/validation"
|
||||
test_dir = self.data_dir / "2023/test"
|
||||
|
||||
# Check if directories exist; if not, download the data
|
||||
if not valid_dir.is_dir() or not test_dir.is_dir():
|
||||
logger.info("Data not found. Downloading data.")
|
||||
self.download()
|
||||
|
||||
# Load metadata for both validation and test datasets
|
||||
for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
|
||||
self._data[label] = []
|
||||
with open(path / "metadata.jsonl", "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
data = json.loads(line)
|
||||
if data["task_id"] == "0-0-0-0-0":
|
||||
continue
|
||||
if data["file_name"]:
|
||||
data["file_name"] = path / data["file_name"]
|
||||
self._data[label].append(data)
|
||||
return self
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
r"""Get the training set."""
|
||||
raise NotImplementedError("GAIA does not have a training set.")
|
||||
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
on: Literal["train", "valid", "test"],
|
||||
level: Union[int, List[int], Literal["all"]],
|
||||
randomize: bool = False,
|
||||
subset: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent to run the benchmark.
|
||||
on (Literal["valid", "test"]): The set to run the benchmark.
|
||||
level (Union[int, List[int], Literal["all"]]): The level to run
|
||||
the benchmark.
|
||||
randomize (bool, optional): Whether to randomize the data.
|
||||
(default: :obj:`False`)
|
||||
subset (Optional[int], optional): The subset of data to run.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The results of the benchmark.
|
||||
"""
|
||||
# Validate inputs
|
||||
if on not in ["valid", "test"]:
|
||||
raise ValueError(
|
||||
f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
|
||||
)
|
||||
|
||||
levels = (
|
||||
[1, 2, 3]
|
||||
if level == "all"
|
||||
else [level]
|
||||
if isinstance(level, int)
|
||||
else level
|
||||
)
|
||||
if not all(
|
||||
isinstance(level, int) and level in [1, 2, 3] for level in levels
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid value for `level`: {level}, expected 1, 2, 3 "
|
||||
"or 'all'."
|
||||
)
|
||||
|
||||
logger.info(f"Running benchmark on {on} set at levels {levels}.")
|
||||
datas = [data for data in self._data[on] if data["Level"] in levels]
|
||||
|
||||
# Shuffle and subset data if necessary
|
||||
if randomize:
|
||||
random.shuffle(datas)
|
||||
if subset:
|
||||
datas = datas[:subset]
|
||||
|
||||
logger.info(f"Number of tasks: {len(datas)}")
|
||||
|
||||
# Initialize results storage
|
||||
self._results = []
|
||||
|
||||
# Process tasks
|
||||
with open(self.save_to, "w") as f:
|
||||
for task in tqdm(datas, desc="Running"):
|
||||
if not self._prepare_task(task):
|
||||
continue
|
||||
|
||||
try:
|
||||
result = agent.step(self._create_user_message(task))
|
||||
self._process_result(agent, task, result, f)
|
||||
except Exception as e:
|
||||
self._handle_error(task, e, f)
|
||||
finally:
|
||||
agent.reset()
|
||||
|
||||
return self._generate_summary()
|
||||
|
||||
def _prepare_task(self, task: Dict[str, Any]) -> bool:
|
||||
r"""Prepare the task by validating and enriching its data."""
|
||||
if task["file_name"]:
|
||||
file_path = Path(task["file_name"])
|
||||
if not file_path.exists():
|
||||
logger.info(
|
||||
f"Skipping task because file not found: {file_path}"
|
||||
)
|
||||
return False
|
||||
if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
|
||||
if not self.retriever.reset(task_id=task["task_id"]):
|
||||
return False
|
||||
retrieved_info = self.retriever.retrieve(
|
||||
query=task["Question"], contents=[task["file_name"]]
|
||||
)
|
||||
retrieved_content = [
|
||||
item["text"]
|
||||
for item in retrieved_info.get("Retrieved Context", [])
|
||||
]
|
||||
if retrieved_content:
|
||||
task["Question"] += "\n" + "\n".join(retrieved_content)
|
||||
else:
|
||||
logger.info(
|
||||
f"Skipping task due to unsupported file "
|
||||
f"format: {file_path.suffix}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
|
||||
r"""Create a user message from a task."""
|
||||
return BaseMessage.make_user_message(
|
||||
role_name="User",
|
||||
content=task["Question"],
|
||||
)
|
||||
|
||||
def _process_result(
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
task: Dict[str, Any],
|
||||
result: Any,
|
||||
file_obj: Any,
|
||||
) -> None:
|
||||
r"""Process and store the result of a task."""
|
||||
model_answer = self.get_final_answer(result.msgs[0].content)
|
||||
final_answer = task["Final answer"]
|
||||
score = self.question_scorer(model_answer, final_answer)
|
||||
tool_calls = result.info.get("tool_calls", [])
|
||||
|
||||
result_data = {
|
||||
"task_id": task["task_id"],
|
||||
"question": task["Question"],
|
||||
"level": task["Level"],
|
||||
"model_answer": model_answer,
|
||||
"ground_truth": final_answer,
|
||||
"tool_calls": [tool.model_dump() for tool in tool_calls],
|
||||
"error": None,
|
||||
"score": int(score),
|
||||
"history": agent.memory.get_context(),
|
||||
}
|
||||
self._results.append(result_data)
|
||||
file_obj.write(
|
||||
json.dumps(result_data, indent=2) + "\n", ensure_ascii=False
|
||||
)
|
||||
file_obj.flush()
|
||||
|
||||
def _handle_error(
|
||||
self, task: Dict[str, Any], error: Exception, file_obj: Any
|
||||
) -> None:
|
||||
r"""Handle errors encountered during task processing."""
|
||||
logger.warning(f"Error processing task {task['task_id']}: {error}")
|
||||
error_data = {
|
||||
"task_id": task["task_id"],
|
||||
"question": task["Question"],
|
||||
"level": task["Level"],
|
||||
"model_answer": "ERROR",
|
||||
"ground_truth": task["Final answer"],
|
||||
"tool_calls": [],
|
||||
"error": str(error),
|
||||
"score": 0,
|
||||
}
|
||||
self._results.append(error_data)
|
||||
file_obj.write(
|
||||
json.dumps(error_data, indent=2) + "\n", ensure_ascii=False
|
||||
)
|
||||
file_obj.flush()
|
||||
|
||||
def _generate_summary(self) -> Dict[str, Any]:
|
||||
r"""Generate and return a summary of the benchmark results."""
|
||||
return {
|
||||
"total": len(self._results),
|
||||
"correct": sum(result["score"] for result in self._results),
|
||||
"results": self._results,
|
||||
}
|
||||
|
||||
def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
|
||||
r"""Scorer for the GAIA benchmark.
|
||||
https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
|
||||
scorer.py
|
||||
|
||||
Args:
|
||||
model_answer (str): The model answer.
|
||||
ground_truth (str): The ground truth answer.
|
||||
|
||||
Returns:
|
||||
bool: The score of the model
|
||||
"""
|
||||
|
||||
def is_float(element: Any) -> bool:
|
||||
try:
|
||||
float(element)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if is_float(ground_truth):
|
||||
logger.info(f"Evaluating {model_answer} as a number.")
|
||||
normalized_answer = self.normalize_number_str(model_answer)
|
||||
return normalized_answer == float(ground_truth)
|
||||
|
||||
elif any(char in ground_truth for char in [",", ";"]):
|
||||
logger.info(
|
||||
f"Evaluating {model_answer} as a comma separated list."
|
||||
)
|
||||
gt_elems = self.split_string(ground_truth)
|
||||
ma_elems = self.split_string(model_answer)
|
||||
|
||||
if len(gt_elems) != len(ma_elems):
|
||||
logger.warning(
|
||||
"Answer lists have different lengths, returning False.",
|
||||
UserWarning,
|
||||
)
|
||||
return False
|
||||
|
||||
comparisons = []
|
||||
for ma_elem, gt_elem in zip(ma_elems, gt_elems):
|
||||
if is_float(gt_elem):
|
||||
normalized_ma_elem = self.normalize_number_str(ma_elem)
|
||||
comparisons.append(normalized_ma_elem == float(gt_elem))
|
||||
else:
|
||||
ma_elem = self.normalize_str(ma_elem, remove_punct=False)
|
||||
gt_elem = self.normalize_str(gt_elem, remove_punct=False)
|
||||
comparisons.append(ma_elem == gt_elem)
|
||||
return all(comparisons)
|
||||
else:
|
||||
logger.info(f"Evaluating {model_answer} as a string.")
|
||||
ma_elem = self.normalize_str(model_answer)
|
||||
gt_elem = self.normalize_str(ground_truth)
|
||||
return ma_elem == gt_elem
|
||||
|
||||
def normalize_number_str(self, number_str: str) -> float:
|
||||
for char in ["$", "%", ","]:
|
||||
number_str = number_str.replace(char, "")
|
||||
try:
|
||||
return float(number_str)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
f"String {number_str} cannot be normalized to number str."
|
||||
)
|
||||
return float("inf")
|
||||
|
||||
def split_string(
|
||||
self, s: str, char_list: Optional[List[str]] = None
|
||||
) -> list[str]:
|
||||
r"""Split a string based on a list of characters.
|
||||
|
||||
Args:
|
||||
s (str): The string to split.
|
||||
char_list (Optional[List[str]], optional): T
|
||||
he list of characters to split on.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
if char_list is None:
|
||||
char_list = [",", ";"]
|
||||
pattern = f"[{''.join(char_list)}]"
|
||||
return re.split(pattern, s)
|
||||
|
||||
def normalize_str(self, input_str, remove_punct=True) -> str:
|
||||
r"""Normalize a string.
|
||||
|
||||
Args:
|
||||
input_str: The input string to normalize.
|
||||
remove_punct: Whether to remove punctuation.
|
||||
|
||||
Returns:
|
||||
str: The normalized string.
|
||||
"""
|
||||
no_spaces = re.sub(r"\s", "", input_str)
|
||||
if remove_punct:
|
||||
translator = str.maketrans("", "", string.punctuation)
|
||||
return no_spaces.lower().translate(translator)
|
||||
else:
|
||||
return no_spaces.lower()
|
||||
|
||||
def get_final_answer(self, content: str) -> str:
|
||||
r"""Get the final answer from the content.
|
||||
|
||||
Args:
|
||||
content (str): The content to extract the final answer from.
|
||||
|
||||
Returns:
|
||||
str: The final answer.
|
||||
"""
|
||||
final_answer_index = content.find("FINAL ANSWER")
|
||||
if final_answer_index == -1:
|
||||
return "FINAL ANSWER not found"
|
||||
start_index = final_answer_index + len("FINAL ANSWER: ")
|
||||
final_answer_content = content[start_index:].strip()
|
||||
return final_answer_content
|
||||
517
camel/benchmarks/nexus.py
Normal file
517
camel/benchmarks/nexus.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.benchmarks.base import BaseBenchmark
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define the data class
|
||||
@dataclass
|
||||
class NexusSample:
|
||||
r"""Nexus benchmark dataset sample."""
|
||||
|
||||
input: str
|
||||
output: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NexusTool:
|
||||
r"""Nexus benchmark tool"""
|
||||
|
||||
function_calls: str
|
||||
descriptions: str
|
||||
|
||||
|
||||
dataset_mapping = {
|
||||
"NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
|
||||
"VirusTotal": "Nexusflow/VirusTotalBenchmark",
|
||||
"PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
|
||||
"ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
|
||||
"OTX": "Nexusflow/OTXAPIBenchmark",
|
||||
"VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
|
||||
"VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
|
||||
"NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
|
||||
}
|
||||
|
||||
TOOL_CALLING_PROMPT = """
|
||||
You are given multiple functions and a user query.
|
||||
|
||||
Please proceed with generating a function call for the function \
|
||||
with the proper arguments that best answers the given prompt.
|
||||
|
||||
Respond with nothing but the function call ONLY, such that I can \
|
||||
directly execute your function call without any post processing \
|
||||
necessary from my end. Do not use variables.
|
||||
If there are more than two function calls, separate them with a semicolon (;).
|
||||
|
||||
{tools}
|
||||
|
||||
Question: {input}
|
||||
"""
|
||||
|
||||
|
||||
class NexusBenchmark(BaseBenchmark):
|
||||
r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
|
||||
Function Calling Benchmark`
|
||||
<https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use.
|
||||
(default: :obj:`1`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
save_to: str,
|
||||
processes: int = 1,
|
||||
):
|
||||
r"""Initialize the Nexus Function Calling benchmark.
|
||||
|
||||
Args:
|
||||
data_dir (str): The directory to save the data.
|
||||
save_to (str): The file to save the results.
|
||||
processes (int, optional): The number of processes to use for
|
||||
parallel processing. (default: :obj:`1`)
|
||||
"""
|
||||
super().__init__("nexus", data_dir, save_to, processes)
|
||||
self._data: List[NexusSample] = [] # type: ignore[assignment]
|
||||
|
||||
def download(self):
|
||||
r"""Download the Nexus Functional Calling Benchmark dataset."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
for dataset_name, repo_id in dataset_mapping.items():
|
||||
local_dir = self.data_dir / dataset_name
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=True,
|
||||
)
|
||||
|
||||
def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
|
||||
r"""Load the Nexus Benchmark dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the specific dataset to be loaded.
|
||||
force_download (bool): Whether to force download the data.
|
||||
"""
|
||||
|
||||
def _load_csv_data(dataset_dir: Path) -> List:
|
||||
r"""Load datasets from CSV files."""
|
||||
dataset = []
|
||||
for file_name in os.listdir(dataset_dir):
|
||||
file_path = dataset_dir / file_name
|
||||
if file_name.endswith(".csv"):
|
||||
data = pd.read_csv(file_path)
|
||||
for _, sample in data.iterrows():
|
||||
dataset.append(
|
||||
NexusSample(
|
||||
sample["Input"], "".join(sample["Output"])
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
logger.warning(f"Skipping unsupported file: {file_name}")
|
||||
return dataset
|
||||
|
||||
def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
|
||||
r"""Load datasets from Parquet files."""
|
||||
dataset = []
|
||||
if not data_dir.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Data directory '{data_dir}' does not exist."
|
||||
)
|
||||
|
||||
for file_name in os.listdir(data_dir):
|
||||
file_path = data_dir / file_name
|
||||
if file_name.endswith(".parquet"):
|
||||
data = pd.read_parquet(file_path)
|
||||
dataset.extend(_process_parquet_data(data, dataset_name))
|
||||
continue
|
||||
|
||||
logger.warning(f"Skipping unsupported file: {file_name}")
|
||||
|
||||
return dataset
|
||||
|
||||
def _process_parquet_data(
|
||||
data: pd.DataFrame, dataset_name: str
|
||||
) -> List:
|
||||
r"""Process data from Parquet files based on dataset name."""
|
||||
dataset: List = []
|
||||
dataset_handlers = {
|
||||
"NVDLibrary": _process_nvdlibrary,
|
||||
"VirusTotal": _process_simple,
|
||||
"PlacesAPI": _process_simple,
|
||||
"ClimateAPI": _process_simple,
|
||||
"OTX": _process_simple,
|
||||
"VirusTotal-NestedCalls": _process_nested_calls,
|
||||
"VirusTotal-ParallelCalls": _process_parallel_calls,
|
||||
}
|
||||
|
||||
if dataset_name not in dataset_handlers:
|
||||
logger.warning(
|
||||
f"No specific handler for dataset: {dataset_name}"
|
||||
)
|
||||
return dataset
|
||||
|
||||
handler = dataset_handlers[dataset_name]
|
||||
for _, sample in data.iterrows():
|
||||
processed_sample = handler(sample)
|
||||
if processed_sample:
|
||||
dataset.append(processed_sample)
|
||||
return dataset
|
||||
|
||||
def _process_nvdlibrary(sample) -> NexusSample:
|
||||
r"""Process samples for the NVDLibrary dataset."""
|
||||
return NexusSample(
|
||||
sample["Input"], sample["Output"].replace("r = nvdlib.", "")
|
||||
)
|
||||
|
||||
def _process_simple(sample) -> NexusSample:
|
||||
r"""Process samples for simple datasets (e.g., VirusTotal)."""
|
||||
return NexusSample(sample["Input"], sample["Output"])
|
||||
|
||||
def _process_nested_calls(sample) -> Union[NexusSample, None]:
|
||||
r"""Process samples for VirusTotal-NestedCalls dataset."""
|
||||
if len(sample["fncall"]) == 1:
|
||||
return NexusSample(
|
||||
sample["generated_question"], "".join(sample["fncall"])
|
||||
)
|
||||
return None
|
||||
|
||||
def _process_parallel_calls(sample) -> Union[NexusSample, None]:
|
||||
r"""Process samples for VirusTotal-ParallelCalls dataset."""
|
||||
if len(sample["fncall"]) > 1:
|
||||
return NexusSample(
|
||||
sample["generated_question"], "; ".join(sample["fncall"])
|
||||
)
|
||||
return None
|
||||
|
||||
if force_download:
|
||||
logger.info("Force downloading data.")
|
||||
self.download()
|
||||
|
||||
# Validate dataset name
|
||||
if dataset_name not in dataset_mapping:
|
||||
available_datasets = list(dataset_mapping.keys())
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' is not recognized. "
|
||||
f"Available datasets: {available_datasets}"
|
||||
)
|
||||
|
||||
# Get the dataset directory
|
||||
dataset_dir = self.data_dir / dataset_name
|
||||
if not dataset_dir.exists():
|
||||
raise FileNotFoundError(
|
||||
f"The dataset directory for '{dataset_name}' \
|
||||
does not exist at {dataset_dir}. "
|
||||
"Please download it first."
|
||||
)
|
||||
|
||||
# Load the dataset
|
||||
if dataset_name == "NVDLibrary-NestedCalls":
|
||||
self._data = _load_csv_data(dataset_dir)
|
||||
else:
|
||||
self._data = _load_parquet_data(dataset_dir / "data", dataset_name)
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
r"""Get the training set."""
|
||||
raise NotImplementedError(
|
||||
"Nexus Functional Calling has only a single 'train' set."
|
||||
)
|
||||
|
||||
def run( # type: ignore[override, return]
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
task: Literal[
|
||||
"NVDLibrary",
|
||||
"VirusTotal",
|
||||
"OTX",
|
||||
"PlacesAPI",
|
||||
"ClimateAPI",
|
||||
"VirusTotal-ParallelCalls",
|
||||
"VirusTotal-NestedCalls",
|
||||
"NVDLibrary-NestedCalls",
|
||||
],
|
||||
randomize: bool = False,
|
||||
subset: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Run the benchmark.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent to run the benchmark.
|
||||
task (Literal["NVDLibrary", "VirusTotal", "OTX",
|
||||
"PlacesAPI", "ClimateAPI", "VirusTotal-ParallelCalls",
|
||||
"VirusTotal-NestedCalls",
|
||||
"NVDLibrary-NestedCalls"]): The task to run the benchmark.
|
||||
randomize (bool, optional): Whether to randomize the data.
|
||||
(default: :obj:`False`)
|
||||
subset (Optional[int], optional): The subset of data to run.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The results of the benchmark.
|
||||
"""
|
||||
|
||||
if task not in dataset_mapping:
|
||||
raise ValueError(f"Invalid value for dataset: {task}.")
|
||||
|
||||
logger.info(f"Running Nexus Function Calling benchmark on {task}.")
|
||||
self.load(task)
|
||||
datas = self._data
|
||||
|
||||
# Shuffle and subset data if necessary
|
||||
if randomize:
|
||||
random.shuffle(datas)
|
||||
if subset:
|
||||
datas = datas[:subset]
|
||||
|
||||
logger.info(f"Number of tasks: {len(datas)}")
|
||||
|
||||
# Initialize results storage
|
||||
self._results = []
|
||||
|
||||
# Process samples
|
||||
tools = construct_tool_descriptions(task)
|
||||
with open(self.save_to, "w") as f:
|
||||
for sample in tqdm(datas, desc="Running"):
|
||||
prompt = construct_prompt(input=sample.input, tools=tools)
|
||||
ground_truth_call = sample.output
|
||||
try:
|
||||
# Generate response
|
||||
response = agent.step(prompt)
|
||||
agent_call = response.msgs[0].content
|
||||
|
||||
# Evaluate response
|
||||
if agent_call:
|
||||
result = compare_function_calls(
|
||||
agent_call=agent_call,
|
||||
ground_truth_call=ground_truth_call,
|
||||
)
|
||||
self._results.append(
|
||||
{
|
||||
"input": sample.input,
|
||||
"agent_call": agent_call,
|
||||
"ground_truth_call": ground_truth_call,
|
||||
"result": result,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in processing task: {sample.input}")
|
||||
self._results.append(
|
||||
{
|
||||
"input": sample.input,
|
||||
"agent_call": None,
|
||||
"ground_truth_call": ground_truth_call,
|
||||
"result": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
agent.reset()
|
||||
|
||||
json_str = json.dumps(
|
||||
self._results[-1], indent=2, ensure_ascii=False
|
||||
)
|
||||
f.write(json_str + "\n")
|
||||
f.flush()
|
||||
|
||||
total = len(self._results)
|
||||
correct = sum(r["result"] for r in self._results)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"accuracy": correct / total,
|
||||
}
|
||||
|
||||
|
||||
# Utility functions
|
||||
def construct_tool_descriptions(dataset_name: str) -> str:
|
||||
r"""Construct tool descriptions from function definitions and
|
||||
descriptions."""
|
||||
tool_dataset_mapping = {
|
||||
"NVDLibrary": "CVECPE",
|
||||
"VirusTotal": "VirusTotal",
|
||||
"PlacesAPI": "Places",
|
||||
"ClimateAPI": "Climate",
|
||||
"OTX": "OTX",
|
||||
"VirusTotal-NestedCalls": "VT_Multi (Nested)",
|
||||
"VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
|
||||
"NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
|
||||
}
|
||||
|
||||
if dataset_name not in tool_dataset_mapping:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' is not recognized. "
|
||||
f"Available datasets: {list(dataset_mapping.keys())}"
|
||||
)
|
||||
|
||||
# Load the dataset based on the dataset name
|
||||
dataset = load_dataset(
|
||||
"Nexusflow/Function_Call_Definitions",
|
||||
name=tool_dataset_mapping[dataset_name],
|
||||
)["train"]
|
||||
|
||||
# Construct tool descriptions
|
||||
tools = [
|
||||
NexusTool(tool["function_calls"], tool["descriptions"])
|
||||
for tool in dataset
|
||||
]
|
||||
|
||||
# Generate the tool prompt
|
||||
tool_prompt = "".join(
|
||||
f"Function:\ndef {tool.function_calls}:\n"
|
||||
+ "\"\"\"\n"
|
||||
+ f"{tool.descriptions}\n"
|
||||
+ "\"\"\"\n"
|
||||
for tool in tools
|
||||
)
|
||||
|
||||
return tool_prompt
|
||||
|
||||
|
||||
def construct_prompt(input: str, tools: str) -> str:
|
||||
r"Construct prompt from tools and input."
|
||||
return TOOL_CALLING_PROMPT.format(tools=tools, input=input)
|
||||
|
||||
|
||||
# Functions for function call evaluation
|
||||
def parse_function_call(
|
||||
call: str,
|
||||
) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
|
||||
r"""Parse a function call string to extract the function name,
|
||||
positional arguments, and keyword arguments, including
|
||||
nested function calls.
|
||||
|
||||
Args:
|
||||
call (str): A string in the format `func(arg1, arg2, kwarg=value)`.
|
||||
|
||||
Returns:
|
||||
tuple: (function_name (str), positional_args (list),
|
||||
keyword_args (dict)) or (None, None, None).
|
||||
"""
|
||||
|
||||
def preprocess_input(call: str) -> str:
|
||||
r"""Remove formatting like code blocks and whitespace."""
|
||||
if call.strip().startswith("```python"):
|
||||
call = call.strip().removeprefix("```python").removesuffix("```")
|
||||
return textwrap.dedent(call).strip()
|
||||
|
||||
def evaluate_arg(arg):
|
||||
r"""Recursively evaluate arguments, including nested calls."""
|
||||
if isinstance(arg, ast.Call):
|
||||
# Recursively parse nested calls
|
||||
func_name, args, kwargs = parse_function_call(ast.unparse(arg))
|
||||
return func_name, args, kwargs
|
||||
elif isinstance(
|
||||
arg, ast.Constant
|
||||
): # Handle literals like numbers, strings, etc.
|
||||
return arg.value
|
||||
elif isinstance(arg, ast.List): # Handle list literals
|
||||
return [evaluate_arg(el) for el in arg.elts]
|
||||
elif isinstance(arg, ast.Dict): # Handle dictionary literals
|
||||
return {
|
||||
evaluate_arg(k): evaluate_arg(v)
|
||||
for k, v in zip(arg.keys, arg.values)
|
||||
}
|
||||
elif isinstance(arg, ast.Tuple): # Handle tuple literals
|
||||
return tuple(evaluate_arg(el) for el in arg.elts)
|
||||
else:
|
||||
return ast.literal_eval(arg) # Safely evaluate other types
|
||||
|
||||
call = preprocess_input(call)
|
||||
parsed_calls = []
|
||||
|
||||
try:
|
||||
# Parse the string into an AST
|
||||
parsed_calls = call.split(";")
|
||||
for single_call in parsed_calls:
|
||||
tree = ast.parse(single_call, mode='eval')
|
||||
|
||||
# Ensure it's a function call
|
||||
if isinstance(tree.body, ast.Call):
|
||||
# Extract function name
|
||||
if isinstance(
|
||||
tree.body.func, ast.Name
|
||||
): # Simple function call
|
||||
func_name = tree.body.func.id
|
||||
elif isinstance(
|
||||
tree.body.func, ast.Attribute
|
||||
): # Attribute function call
|
||||
func_name = (
|
||||
f"{tree.body.func.value.id}.{tree.body.func.attr}" # type: ignore[attr-defined]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported function call: {call}")
|
||||
|
||||
# Extract positional arguments
|
||||
args = [evaluate_arg(arg) for arg in tree.body.args]
|
||||
|
||||
# Extract keyword arguments
|
||||
kwargs: Dict[str, Any] = {
|
||||
kw.arg: evaluate_arg(kw.value)
|
||||
for kw in tree.body.keywords
|
||||
if kw.arg is not None
|
||||
}
|
||||
logger.info("Valid call.")
|
||||
return func_name, args, kwargs
|
||||
else:
|
||||
raise ValueError(f"Not a valid function call: {call}")
|
||||
except Exception as e:
|
||||
logger.info(f"Error parsing call: {call}, {e}")
|
||||
return None, None, None
|
||||
|
||||
|
||||
def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
|
||||
r"""Compare the function name and arguments of
|
||||
agent_call and ground_truth_call.
|
||||
Args:
|
||||
agent_call (str): Function call by agent.
|
||||
ground_truth_call (str): Ground truth function call.
|
||||
|
||||
Returns:
|
||||
- `True` if the function names and arguments match.
|
||||
- `False` otherwise.
|
||||
"""
|
||||
# Parse both calls
|
||||
agent_parsed = parse_function_call(agent_call)
|
||||
gt_parsed = parse_function_call(ground_truth_call)
|
||||
|
||||
if agent_parsed and gt_parsed:
|
||||
return agent_parsed == gt_parsed
|
||||
else:
|
||||
return False
|
||||
333
camel/benchmarks/ragbench.py
Normal file
333
camel/benchmarks/ragbench.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.benchmarks import BaseBenchmark
|
||||
from camel.logger import get_logger
|
||||
from camel.retrievers import AutoRetriever
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RagasFields:
|
||||
r"""Constants for RAGAS evaluation field names."""
|
||||
|
||||
INPUT_CONTEXT = "contexts"
|
||||
INPUT_QUESTION = "question"
|
||||
INPUT_ANSWER = "answer"
|
||||
|
||||
|
||||
def annotate_dataset(
|
||||
dataset: Dataset,
|
||||
context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
|
||||
answer_call: Optional[Callable[[Dict[str, Any]], str]],
|
||||
) -> Dataset:
|
||||
r"""Annotate the dataset by adding context and answers using the provided
|
||||
functions.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The input dataset to annotate.
|
||||
context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
|
||||
Function to generate context for each example.
|
||||
answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function to
|
||||
generate answer for each example.
|
||||
|
||||
Returns:
|
||||
Dataset: The annotated dataset with added contexts and/or answers.
|
||||
"""
|
||||
|
||||
def process_example(example: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if context_call:
|
||||
example["contexts"] = context_call(example)
|
||||
if answer_call:
|
||||
example["answer"] = answer_call(example)
|
||||
return example
|
||||
|
||||
return dataset.map(process_example)
|
||||
|
||||
|
||||
def rmse(
|
||||
input_trues: Sequence[float],
|
||||
input_preds: Sequence[float],
|
||||
) -> Optional[float]:
|
||||
r"""Calculate Root Mean Squared Error (RMSE).
|
||||
|
||||
Args:
|
||||
input_trues (Sequence[float]): Ground truth values.
|
||||
input_preds (Sequence[float]): Predicted values.
|
||||
|
||||
Returns:
|
||||
Optional[float]: RMSE value, or None if inputs have different lengths.
|
||||
"""
|
||||
if len(input_trues) != len(input_preds):
|
||||
logger.warning("Input lengths mismatch in RMSE calculation")
|
||||
return None
|
||||
|
||||
trues = np.array(input_trues)
|
||||
preds = np.array(input_preds, dtype=float)
|
||||
|
||||
# Ignore NaN values in predictions
|
||||
eval_idx = ~np.isnan(preds)
|
||||
if not np.any(eval_idx):
|
||||
logger.warning("No valid predictions for RMSE calculation")
|
||||
return None
|
||||
|
||||
trues = trues[eval_idx]
|
||||
preds = preds[eval_idx]
|
||||
|
||||
return float(np.sqrt(np.mean((preds - trues) ** 2)))
|
||||
|
||||
|
||||
def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
|
||||
r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).
|
||||
|
||||
Args:
|
||||
trues (Sequence[bool]): Ground truth binary values.
|
||||
preds (Sequence[float]): Predicted probability values.
|
||||
|
||||
Returns:
|
||||
float: AUROC score.
|
||||
"""
|
||||
from sklearn.metrics import roc_auc_score # type: ignore[import-untyped]
|
||||
|
||||
eval_idx = ~np.isnan(preds)
|
||||
if not np.any(eval_idx):
|
||||
logger.warning("No valid predictions for AUROC calculation")
|
||||
return 0.5 # Return random classifier score
|
||||
|
||||
return float(
|
||||
roc_auc_score(np.array(trues)[eval_idx], np.array(preds)[eval_idx])
|
||||
)
|
||||
|
||||
|
||||
def ragas_calculate_metrics(
|
||||
dataset: Dataset,
|
||||
pred_context_relevance_field: Optional[str],
|
||||
pred_faithfulness_field: Optional[str],
|
||||
metrics_to_evaluate: Optional[List[str]] = None,
|
||||
ground_truth_context_relevance_field: str = "relevance_score",
|
||||
ground_truth_faithfulness_field: str = "adherence_score",
|
||||
) -> Dict[str, Optional[float]]:
|
||||
r"""Calculate RAGAS evaluation metrics.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset containing predictions and ground truth.
|
||||
pred_context_relevance_field (Optional[str]): Field name for predicted
|
||||
context relevance.
|
||||
pred_faithfulness_field (Optional[str]): Field name for predicted
|
||||
faithfulness.
|
||||
metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
|
||||
ground_truth_context_relevance_field (str): Field name for ground truth
|
||||
relevance.
|
||||
ground_truth_faithfulness_field (str): Field name for ground truth
|
||||
adherence.
|
||||
|
||||
Returns:
|
||||
Dict[str, Optional[float]]: Dictionary of calculated metrics.
|
||||
"""
|
||||
metrics_to_evaluate = metrics_to_evaluate or [
|
||||
"context_relevancy",
|
||||
"faithfulness",
|
||||
]
|
||||
calculated_metrics: Dict[str, Optional[float]] = {}
|
||||
|
||||
if (
|
||||
"context_relevancy" in metrics_to_evaluate
|
||||
and pred_context_relevance_field
|
||||
):
|
||||
trues_relevance = dataset[ground_truth_context_relevance_field]
|
||||
preds_relevance = dataset[pred_context_relevance_field]
|
||||
calculated_metrics["relevance_rmse"] = rmse(
|
||||
trues_relevance, preds_relevance
|
||||
)
|
||||
|
||||
if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
|
||||
trues_hallucination = ~np.array(
|
||||
dataset[ground_truth_faithfulness_field]
|
||||
)
|
||||
preds_hallucination = 1 - np.array(
|
||||
dataset[pred_faithfulness_field], dtype=float
|
||||
)
|
||||
calculated_metrics["hallucination_auroc"] = auroc(
|
||||
trues_hallucination.tolist(), preds_hallucination.tolist()
|
||||
)
|
||||
|
||||
return calculated_metrics
|
||||
|
||||
|
||||
def ragas_evaluate_dataset(
|
||||
dataset: Dataset,
|
||||
contexts_field_name: Optional[str],
|
||||
answer_field_name: Optional[str],
|
||||
metrics_to_evaluate: Optional[List[str]] = None,
|
||||
) -> Dataset:
|
||||
r"""Evaluate the dataset using RAGAS metrics.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): Input dataset to evaluate.
|
||||
contexts_field_name (Optional[str]): Field name containing contexts.
|
||||
answer_field_name (Optional[str]): Field name containing answers.
|
||||
metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
|
||||
|
||||
Returns:
|
||||
Dataset: Dataset with added evaluation metrics.
|
||||
"""
|
||||
from ragas import evaluate # type: ignore[import]
|
||||
from ragas.metrics import ( # type: ignore[import]
|
||||
context_relevancy,
|
||||
faithfulness,
|
||||
)
|
||||
|
||||
metrics_to_evaluate = metrics_to_evaluate or [
|
||||
"context_relevancy",
|
||||
"faithfulness",
|
||||
]
|
||||
|
||||
# Rename fields if necessary
|
||||
if (
|
||||
contexts_field_name
|
||||
and contexts_field_name != RagasFields.INPUT_CONTEXT
|
||||
):
|
||||
dataset = dataset.rename_column(
|
||||
contexts_field_name, RagasFields.INPUT_CONTEXT
|
||||
)
|
||||
if answer_field_name and answer_field_name != RagasFields.INPUT_ANSWER:
|
||||
dataset = dataset.rename_column(
|
||||
answer_field_name, RagasFields.INPUT_ANSWER
|
||||
)
|
||||
|
||||
metrics = []
|
||||
if "context_relevancy" in metrics_to_evaluate:
|
||||
metrics.append(context_relevancy)
|
||||
if "faithfulness" in metrics_to_evaluate:
|
||||
metrics.append(faithfulness)
|
||||
|
||||
ragas_result = evaluate(dataset, metrics=metrics)
|
||||
return Dataset.from_pandas(ragas_result.to_pandas())
|
||||
|
||||
|
||||
class RAGBenchBenchmark(BaseBenchmark):
|
||||
r"""RAGBench Benchmark for evaluating RAG performance.
|
||||
|
||||
This benchmark uses the rungalileo/ragbench dataset to evaluate
|
||||
retrieval-augmented generation (RAG) systems. It measures context
|
||||
relevancy and faithfulness metrics as described in
|
||||
https://arxiv.org/abs/2407.11005.
|
||||
|
||||
Args:
|
||||
processes (int, optional): Number of processes for parallel processing.
|
||||
subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
|
||||
split (str, optional): Dataset split to use (e.g., "test").
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processes: int = 1,
|
||||
subset: Literal[
|
||||
"covidqa",
|
||||
"cuad",
|
||||
"delucionqa",
|
||||
"emanual",
|
||||
"expertqa",
|
||||
"finqa",
|
||||
"hagrid",
|
||||
"hotpotqa",
|
||||
"msmarco",
|
||||
"pubmedqa",
|
||||
"tatqa",
|
||||
"techqa",
|
||||
] = "hotpotqa",
|
||||
split: Literal["train", "test", "validation"] = "test",
|
||||
) -> None:
|
||||
super().__init__("ragbench", "rag_bench", "", processes)
|
||||
self.subset = subset
|
||||
self.split = split
|
||||
self.dataset: Optional[Dataset] = None
|
||||
|
||||
def download(self):
|
||||
r"""Download the RAGBench dataset."""
|
||||
try:
|
||||
self.dataset = load_dataset(
|
||||
"rungalileo/ragbench", self.subset, split=self.split
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download dataset: {e}")
|
||||
raise
|
||||
|
||||
def load(self, force_download: bool = False):
|
||||
r"""Load the RAGBench dataset.
|
||||
|
||||
Args:
|
||||
force_download (bool, optional): Whether to force download the
|
||||
data.
|
||||
"""
|
||||
if force_download or self.dataset is None:
|
||||
logger.info(
|
||||
"%s dataset",
|
||||
"Force downloading" if force_download else "Loading",
|
||||
)
|
||||
self.download()
|
||||
|
||||
def run( # type: ignore[override, return]
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
auto_retriever: AutoRetriever,
|
||||
) -> Dict[str, Optional[float]]:
|
||||
r"""Run the benchmark evaluation.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): Chat agent for generating answers.
|
||||
auto_retriever (AutoRetriever): Retriever for finding relevant
|
||||
contexts.
|
||||
|
||||
Returns:
|
||||
Dict[str, Optional[float]]: Dictionary of evaluation metrics.
|
||||
"""
|
||||
|
||||
def context_call(example):
|
||||
retrieved_info = auto_retriever.run_vector_retriever(
|
||||
query=example['question'],
|
||||
contents=example['documents'],
|
||||
top_k=1,
|
||||
return_detailed_info=True,
|
||||
similarity_threshold=0.5,
|
||||
)
|
||||
return [c['text'] for c in retrieved_info['Retrieved Context']]
|
||||
|
||||
def answer_call(example: Dict[str, Any]) -> str:
|
||||
user_msg = str(example)
|
||||
assistant_response = agent.step(user_msg)
|
||||
return assistant_response.msg.content
|
||||
|
||||
# Annotate the dataset
|
||||
annotated_ds = annotate_dataset(
|
||||
self.dataset, context_call, answer_call
|
||||
)
|
||||
evaluated_ds = ragas_evaluate_dataset(
|
||||
annotated_ds,
|
||||
contexts_field_name="contexts",
|
||||
answer_field_name="answer",
|
||||
metrics_to_evaluate=["context_relevancy", "faithfulness"],
|
||||
)
|
||||
|
||||
return ragas_calculate_metrics(
|
||||
evaluated_ds,
|
||||
pred_context_relevance_field="context_relevancy",
|
||||
pred_faithfulness_field="faithfulness",
|
||||
)
|
||||
34
camel/bots/__init__.py
Normal file
34
camel/bots/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .discord import DiscordApp
|
||||
from .slack.models import (
|
||||
SlackAppMentionEventBody,
|
||||
SlackAppMentionEventProfile,
|
||||
SlackAuthProfile,
|
||||
SlackEventBody,
|
||||
SlackEventProfile,
|
||||
)
|
||||
from .slack.slack_app import SlackApp
|
||||
from .telegram_bot import TelegramBot
|
||||
|
||||
__all__ = [
|
||||
'DiscordApp',
|
||||
'SlackApp',
|
||||
'SlackAppMentionEventBody',
|
||||
'SlackAppMentionEventProfile',
|
||||
'SlackAuthProfile',
|
||||
'SlackEventBody',
|
||||
'SlackEventProfile',
|
||||
'TelegramBot',
|
||||
]
|
||||
26
camel/bots/discord/__init__.py
Normal file
26
camel/bots/discord/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .discord_app import DiscordApp
|
||||
from .discord_installation import DiscordInstallation
|
||||
from .discord_store import (
|
||||
DiscordBaseInstallationStore,
|
||||
DiscordSQLiteInstallationStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DiscordApp",
|
||||
"DiscordInstallation",
|
||||
"DiscordSQLiteInstallationStore",
|
||||
"DiscordBaseInstallationStore",
|
||||
]
|
||||
384
camel/bots/discord/discord_app.py
Normal file
384
camel/bots/discord/discord_app.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import discord
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
|
||||
from camel.bots.discord.discord_installation import DiscordInstallation
|
||||
from camel.logger import get_logger
|
||||
from camel.utils import api_keys_required, dependencies_required
|
||||
|
||||
from .discord_store import DiscordBaseInstallationStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord import Message
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TOKEN_URL = "https://discord.com/api/oauth2/token"
|
||||
USER_URL = "https://discord.com/api/users/@me"
|
||||
|
||||
|
||||
class DiscordApp:
|
||||
r"""A class representing a Discord app that uses the `discord.py` library
|
||||
to interact with Discord servers.
|
||||
|
||||
This bot can respond to messages in specific channels and only reacts to
|
||||
messages that mention the bot.
|
||||
|
||||
Attributes:
|
||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
|
||||
provided, the bot will only respond to messages in these channels.
|
||||
token (Optional[str]): The Discord bot token used for authentication.
|
||||
"""
|
||||
|
||||
@dependencies_required('discord')
|
||||
@api_keys_required(
|
||||
[
|
||||
("token", "DISCORD_BOT_TOKEN"),
|
||||
]
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
channel_ids: Optional[List[int]] = None,
|
||||
token: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
client_secret: Optional[str] = None,
|
||||
redirect_uri: Optional[str] = None,
|
||||
installation_store: Optional[DiscordBaseInstallationStore] = None,
|
||||
intents: Optional[discord.Intents] = None,
|
||||
) -> None:
|
||||
r"""Initialize the DiscordApp instance by setting up the Discord client
|
||||
and event handlers.
|
||||
|
||||
Args:
|
||||
channel_ids (Optional[List[int]]): A list of allowed channel IDs.
|
||||
The bot will only respond to messages in these channels if
|
||||
provided. (default: :obj:`None`)
|
||||
token (Optional[str]): The Discord bot token for authentication.
|
||||
If not provided, the token will be retrieved from the
|
||||
environment variable `DISCORD_TOKEN`. (default: :obj:`None`)
|
||||
client_id (str, optional): The client ID for Discord OAuth.
|
||||
(default: :obj:`None`)
|
||||
client_secret (Optional[str]): The client secret for Discord OAuth.
|
||||
(default: :obj:`None`)
|
||||
redirect_uri (str): The redirect URI for OAuth callbacks.
|
||||
(default: :obj:`None`)
|
||||
installation_store (DiscordAsyncInstallationStore): The database
|
||||
stores all information of all installations.
|
||||
(default: :obj:`None`)
|
||||
intents (discord.Intents): The Discord intents of this app.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Raises:
|
||||
ValueError: If the `DISCORD_BOT_TOKEN` is not found in environment
|
||||
variables.
|
||||
"""
|
||||
self.token = token or os.getenv("DISCORD_BOT_TOKEN")
|
||||
self.channel_ids = channel_ids
|
||||
self.installation_store = installation_store
|
||||
|
||||
if not intents:
|
||||
intents = discord.Intents.all()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client = discord.Client(intents=intents)
|
||||
|
||||
# Register event handlers
|
||||
self._client.event(self.on_ready)
|
||||
self._client.event(self.on_message)
|
||||
|
||||
# OAuth flow
|
||||
self.client_id = client_id or os.getenv("DISCORD_CLIENT_ID")
|
||||
self.client_secret = client_secret or os.getenv(
|
||||
"DISCORD_CLIENT_SECRET"
|
||||
)
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
self.oauth_flow = bool(
|
||||
self.client_id
|
||||
and self.client_secret
|
||||
and self.redirect_uri
|
||||
and self.installation_store
|
||||
)
|
||||
|
||||
self.app = FastAPI()
|
||||
|
||||
async def start(self):
|
||||
r"""Asynchronously start the Discord bot using its token.
|
||||
|
||||
This method starts the bot and logs into Discord asynchronously using
|
||||
the provided token. It should be awaited when used in an async
|
||||
environment.
|
||||
"""
|
||||
await self._client.start(self.token)
|
||||
|
||||
def run(self) -> None:
|
||||
r"""Start the Discord bot using its token.
|
||||
|
||||
This method starts the bot and logs into Discord synchronously using
|
||||
the provided token. It blocks execution and keeps the bot running.
|
||||
"""
|
||||
self._client.run(self.token) # type: ignore[arg-type]
|
||||
|
||||
async def exchange_code_for_token_response(
|
||||
self, code: str
|
||||
) -> Optional[str]:
|
||||
r"""Exchange the authorization code for an access token.
|
||||
|
||||
Args:
|
||||
code (str): The authorization code received from Discord after
|
||||
user authorization.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The access token if successful, otherwise None.
|
||||
|
||||
Raises:
|
||||
ValueError: If OAuth configuration is incomplete or invalid.
|
||||
httpx.RequestError: If there is a network issue during the request.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
TOKEN_URL, data=data, headers=headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to exchange code: {response.text}")
|
||||
return None
|
||||
response_data = response.json()
|
||||
|
||||
return response_data
|
||||
except (httpx.RequestError, ValueError) as e:
|
||||
logger.error(f"Error during token fetch: {e}")
|
||||
return None
|
||||
|
||||
async def get_user_info(self, access_token: str) -> Optional[dict]:
|
||||
r"""Retrieve user information using the access token.
|
||||
|
||||
Args:
|
||||
access_token (str): The access token received from Discord.
|
||||
|
||||
Returns:
|
||||
dict: The user information retrieved from Discord.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_response = await client.get(USER_URL, headers=headers)
|
||||
return user_response.json()
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> Optional[str]:
|
||||
r"""Refresh the access token using a refresh token.
|
||||
|
||||
Args:
|
||||
refresh_token (str): The refresh token issued by Discord that
|
||||
can be used to obtain a new access token.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The new access token if successful, otherwise None.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(TOKEN_URL, data=data, headers=headers)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to refresh token: {response.text}")
|
||||
return None
|
||||
response_data = response.json()
|
||||
return response_data.get("access_token")
|
||||
|
||||
async def get_valid_access_token(self, guild_id: str) -> Optional[str]:
|
||||
r"""Retrieve a valid access token for the specified guild.
|
||||
|
||||
This method attempts to retrieve an access token for a specific guild.
|
||||
If the current access token is expired, it will refresh the token using
|
||||
the refresh token.
|
||||
|
||||
Args:
|
||||
guild_id (str): The ID of the guild to retrieve the access
|
||||
token for.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The valid access token if successful,
|
||||
otherwise None.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
assert self.installation_store is not None
|
||||
installation = await self.installation_store.find_by_guild(
|
||||
guild_id=guild_id
|
||||
)
|
||||
if not installation:
|
||||
logger.error(f"No installation found for guild: {guild_id}")
|
||||
return None
|
||||
|
||||
if (
|
||||
installation.token_expires_at
|
||||
and datetime.now() >= installation.token_expires_at
|
||||
):
|
||||
logger.info(
|
||||
f"Access token expired for guild: {guild_id}, "
|
||||
f"refreshing token..."
|
||||
)
|
||||
new_access_token = await self.refresh_access_token(
|
||||
installation.refresh_token
|
||||
)
|
||||
if new_access_token:
|
||||
installation.access_token = new_access_token
|
||||
installation.token_expires_at = datetime.now() + timedelta(
|
||||
seconds=3600
|
||||
)
|
||||
await self.installation_store.save(installation)
|
||||
return new_access_token
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to refresh access token for guild: {guild_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
return installation.access_token
|
||||
|
||||
async def save_installation(
|
||||
self,
|
||||
guild_id: str,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
):
|
||||
r"""Save the installation information for a given guild.
|
||||
|
||||
Args:
|
||||
guild_id (str): The ID of the guild where the bot is installed.
|
||||
access_token (str): The access token for the guild.
|
||||
refresh_token (str): The refresh token for the guild.
|
||||
expires_in: (int): The expiration time of the
|
||||
access token.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
assert self.installation_store is not None
|
||||
expires_at = datetime.now() + timedelta(seconds=expires_in)
|
||||
installation = DiscordInstallation(
|
||||
guild_id=guild_id,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
installed_at=datetime.now(),
|
||||
token_expires_at=expires_at,
|
||||
)
|
||||
await self.installation_store.save(installation)
|
||||
logger.info(f"Installation saved for guild: {guild_id}")
|
||||
|
||||
async def remove_installation(self, guild: discord.Guild):
|
||||
r"""Remove the installation for a given guild.
|
||||
|
||||
Args:
|
||||
guild (discord.Guild): The guild from which the bot is
|
||||
being removed.
|
||||
"""
|
||||
if not self.oauth_flow:
|
||||
logger.warning(
|
||||
"OAuth is not enabled. Missing client_id, "
|
||||
"client_secret, or redirect_uri."
|
||||
)
|
||||
return None
|
||||
assert self.installation_store is not None
|
||||
await self.installation_store.delete(guild_id=str(guild.id))
|
||||
print(f"Bot removed from guild: {guild.id}")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
r"""Event handler that is called when the bot has successfully
|
||||
connected to the Discord server.
|
||||
|
||||
When the bot is ready and logged into Discord, it prints a message
|
||||
displaying the bot's username.
|
||||
"""
|
||||
logger.info(f'We have logged in as {self._client.user}')
|
||||
|
||||
async def on_message(self, message: 'Message') -> None:
|
||||
r"""Event handler for processing incoming messages.
|
||||
|
||||
This method is called whenever a new message is received by the bot. It
|
||||
will ignore messages sent by the bot itself, only respond to messages
|
||||
in allowed channels (if specified), and only to messages that mention
|
||||
the bot.
|
||||
|
||||
Args:
|
||||
message (discord.Message): The message object received from
|
||||
Discord.
|
||||
"""
|
||||
# If the message author is the bot itself,
|
||||
# do not respond to this message
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# If allowed channel IDs are provided,
|
||||
# only respond to messages in those channels
|
||||
if self.channel_ids and message.channel.id not in self.channel_ids:
|
||||
return
|
||||
|
||||
# Only respond to messages that mention the bot
|
||||
if not self._client.user or not self._client.user.mentioned_in(
|
||||
message
|
||||
):
|
||||
return
|
||||
|
||||
logger.info(f"Received message: {message.content}")
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self._client
|
||||
64
camel/bots/discord/discord_installation.py
Normal file
64
camel/bots/discord/discord_installation.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DiscordInstallation:
|
||||
r"""Represents an installation of a Discord application in a
|
||||
specific guild (server).
|
||||
|
||||
Attributes:
|
||||
guild_id (str): The unique identifier for the Discord guild (server)
|
||||
where the application is installed.
|
||||
access_token (str): The access token used to authenticate API requests
|
||||
for the installed application.
|
||||
refresh_token (str): The token used to refresh the access token when
|
||||
it expires.
|
||||
installed_at (datetime): The timestamp indicating when the application
|
||||
was installed in the guild.
|
||||
token_expires_at (Optional[datetime]): The optional timestamp
|
||||
indicating when the access token will expire. Defaults to None
|
||||
if the token does not have an expiration time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guild_id: str,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
installed_at: datetime,
|
||||
token_expires_at: Optional[datetime] = None,
|
||||
):
|
||||
r"""Initialize the DiscordInstallation.
|
||||
|
||||
Args:
|
||||
guild_id (str): The unique identifier for the Discord guild
|
||||
(server) where the application is installed.
|
||||
access_token (str): The access token used to authenticate API
|
||||
requests for the installed application.
|
||||
refresh_token (str): The token used to refresh the access token
|
||||
when it expires.
|
||||
installed_at (datetime): The timestamp indicating when the
|
||||
application was installed in the guild.
|
||||
token_expires_at (Optional[datetime]): The optional timestamp
|
||||
indicating when the access token will expire. Defaults to None
|
||||
if the token does not have an expiration time.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
self.guild_id = guild_id
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
self.installed_at = installed_at
|
||||
self.token_expires_at = token_expires_at
|
||||
160
camel/bots/discord/discord_store.py
Normal file
160
camel/bots/discord/discord_store.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .discord_installation import DiscordInstallation
|
||||
|
||||
|
||||
class DiscordBaseInstallationStore:
|
||||
r"""Abstract base class for managing Discord installations.
|
||||
|
||||
This class defines the interface for database operations related to storing
|
||||
and retrieving Discord installation data. Subclasses must implement these
|
||||
methods to handle database-specific logic.
|
||||
"""
|
||||
|
||||
async def init(self):
|
||||
r"""Initializes the database connection or structure."""
|
||||
pass
|
||||
|
||||
async def save(self, installation: DiscordInstallation):
|
||||
r"""Saves or updates a Discord installation record."""
|
||||
pass
|
||||
|
||||
async def find_by_guild(
|
||||
self, guild_id: str
|
||||
) -> Optional[DiscordInstallation]:
|
||||
r"""Finds an installation record by guild ID."""
|
||||
pass
|
||||
|
||||
async def delete(self, guild_id: str):
|
||||
r"""Deletes an installation record by guild ID."""
|
||||
pass
|
||||
|
||||
|
||||
class DiscordSQLiteInstallationStore(DiscordBaseInstallationStore):
|
||||
r"""SQLite-based implementation for managing Discord installations.
|
||||
|
||||
This class provides methods for initializing the database, saving,
|
||||
retrieving, and deleting installation records using SQLite.
|
||||
|
||||
Attributes:
|
||||
database (str): Path to the SQLite database file.
|
||||
"""
|
||||
|
||||
def __init__(self, database: str):
|
||||
r"""Initializes the SQLite installation store.
|
||||
|
||||
Args:
|
||||
database (str): Path to the SQLite database file.
|
||||
"""
|
||||
self.database = database
|
||||
|
||||
async def init(self):
|
||||
r"""Initializes the database by creating the required table if it
|
||||
does not exist."""
|
||||
import aiosqlite
|
||||
|
||||
async with aiosqlite.connect(self.database) as db:
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS discord_installations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
guild_id TEXT NOT NULL UNIQUE,
|
||||
access_token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
installed_at DATETIME NOT NULL,
|
||||
token_expires_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def save(self, installation: DiscordInstallation):
|
||||
r"""Saves a new installation record or updates an existing one.
|
||||
|
||||
Args:
|
||||
installation (DiscordInstallation): The installation data to save.
|
||||
"""
|
||||
import aiosqlite
|
||||
|
||||
async with aiosqlite.connect(self.database) as db:
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO discord_installations (
|
||||
guild_id, access_token, refresh_token,
|
||||
installed_at, token_expires_at
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(guild_id) DO UPDATE SET
|
||||
access_token = excluded.access_token,
|
||||
refresh_token = excluded.refresh_token,
|
||||
token_expires_at = excluded.token_expires_at;
|
||||
""",
|
||||
[
|
||||
installation.guild_id,
|
||||
installation.access_token,
|
||||
installation.refresh_token,
|
||||
installation.installed_at,
|
||||
installation.token_expires_at,
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def find_by_guild(
|
||||
self, guild_id: str
|
||||
) -> Optional[DiscordInstallation]:
|
||||
r"""Finds an installation record by guild ID.
|
||||
|
||||
Args:
|
||||
guild_id (str): The guild ID to search for.
|
||||
|
||||
Returns:
|
||||
Optional[DiscordInstallation]: The installation record if found,
|
||||
otherwise None.
|
||||
"""
|
||||
import aiosqlite
|
||||
|
||||
async with aiosqlite.connect(self.database) as db:
|
||||
async with db.execute(
|
||||
"SELECT guild_id, access_token, refresh_token, "
|
||||
"installed_at, token_expires_at FROM discord_installations "
|
||||
"WHERE guild_id = ?",
|
||||
[guild_id],
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return DiscordInstallation(
|
||||
guild_id=row[0],
|
||||
access_token=row[1],
|
||||
refresh_token=row[2],
|
||||
installed_at=row[3],
|
||||
token_expires_at=row[4],
|
||||
)
|
||||
return None
|
||||
|
||||
async def delete(self, guild_id: str):
|
||||
r"""Deletes an installation record by guild ID.
|
||||
|
||||
Args:
|
||||
guild_id (str): The guild ID of the record to delete.
|
||||
"""
|
||||
import aiosqlite
|
||||
|
||||
async with aiosqlite.connect(self.database) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM discord_installations WHERE guild_id = ?",
|
||||
[guild_id],
|
||||
)
|
||||
await db.commit()
|
||||
30
camel/bots/slack/__init__.py
Normal file
30
camel/bots/slack/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .models import (
|
||||
SlackAppMentionEventBody,
|
||||
SlackAppMentionEventProfile,
|
||||
SlackAuthProfile,
|
||||
SlackEventBody,
|
||||
SlackEventProfile,
|
||||
)
|
||||
from .slack_app import SlackApp
|
||||
|
||||
__all__ = [
|
||||
'SlackApp',
|
||||
'SlackAppMentionEventBody',
|
||||
'SlackAppMentionEventProfile',
|
||||
'SlackAuthProfile',
|
||||
'SlackEventBody',
|
||||
'SlackEventProfile',
|
||||
]
|
||||
158
camel/bots/slack/models.py
Normal file
158
camel/bots/slack/models.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlackAuthProfile(BaseModel):
|
||||
r"""Represents the authorization profile within a Slack event.
|
||||
|
||||
Events will contain a single, compact authorizations field that shows one
|
||||
installation of your app that the event is visible to.
|
||||
In other words, lists of authorizations will be truncated to one element.
|
||||
|
||||
If there's more than one installing party that your app is keeping track
|
||||
of, it's best not to rely on the single party listed in authorizations to
|
||||
be any particular one.
|
||||
|
||||
To get a full list of who can see events, call the apps.event.
|
||||
authorizations.list method after obtaining an app-level token. Read more on
|
||||
the changes here; they have taken effect for existing apps as of
|
||||
February 24, 2021.
|
||||
|
||||
References:
|
||||
|
||||
- https://api.slack.com/apis/events-api#authorizations
|
||||
- https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
|
||||
"""
|
||||
|
||||
enterprise_id: Optional[str] = None
|
||||
"""The ID of the enterprise associated with the authorization."""
|
||||
|
||||
team_id: str
|
||||
"""The ID of the team associated with the authorization."""
|
||||
|
||||
user_id: str
|
||||
"""The ID of the user associated with the authorization."""
|
||||
|
||||
is_bot: bool
|
||||
"""Whether the authorized user is a bot."""
|
||||
|
||||
is_enterprise_install: bool
|
||||
"""Whether the authorization is for an enterprise installation."""
|
||||
|
||||
|
||||
class SlackEventProfile(BaseModel):
|
||||
r"""Represents the detailed profile of a Slack event, including user,
|
||||
message, and context data.
|
||||
"""
|
||||
|
||||
user: str
|
||||
"""The ID of the user associated with the event."""
|
||||
|
||||
type: str
|
||||
"""The type of the event (e.g., 'message')."""
|
||||
|
||||
ts: str
|
||||
"""A timestamp representing when the event was triggered."""
|
||||
|
||||
thread_ts: Optional[str] = None
|
||||
"""The timestamp of the parent message in a thread."""
|
||||
|
||||
client_msg_id: str
|
||||
"""A unique ID generated by the client for the message (if available)."""
|
||||
|
||||
text: str
|
||||
"""The message content text."""
|
||||
|
||||
team: str
|
||||
"""The ID of the team that the event is associated with."""
|
||||
|
||||
blocks: list
|
||||
"""The list of message blocks, providing structured information."""
|
||||
|
||||
channel: str
|
||||
"""The ID of the Slack channel where the event happened."""
|
||||
|
||||
event_ts: str
|
||||
"""The event-specific timestamp when it occurred."""
|
||||
|
||||
channel_type: Optional[str]
|
||||
"""The type of Slack channel (e.g., 'channel', 'im')."""
|
||||
|
||||
|
||||
class SlackEventBody(BaseModel):
|
||||
r"""Represents the entire body of a Slack event, including the event
|
||||
profile, authorization, and context.
|
||||
"""
|
||||
|
||||
token: str
|
||||
"""The token to verify the source of the event."""
|
||||
|
||||
team_id: str
|
||||
"""The ID of the team where the event is happening."""
|
||||
|
||||
context_team_id: Optional[str]
|
||||
"""The team ID for the shared channel context, if applicable."""
|
||||
|
||||
context_enterprise_id: Optional[str] = None
|
||||
"""The enterprise ID for the shared channel context, if applicable."""
|
||||
|
||||
api_app_id: str
|
||||
"""The unique identifier for the Slack app that received the event."""
|
||||
|
||||
event: SlackEventProfile
|
||||
"""A detailed profile of the event"""
|
||||
|
||||
type: str
|
||||
"""The overall type of event received (e.g., 'event_callback')."""
|
||||
|
||||
event_id: str
|
||||
"""A unique identifier assigned to this event by Slack."""
|
||||
|
||||
event_time: int
|
||||
"""The timestamp (in seconds) representing when the event was triggered."""
|
||||
|
||||
authorizations: Optional[list[SlackAuthProfile]] = None
|
||||
"""An optional list of authorizations that describe which installation can
|
||||
see the event."""
|
||||
|
||||
is_ext_shared_channel: bool
|
||||
"""Indicates if the event is part of a shared channel between different
|
||||
organizations."""
|
||||
|
||||
event_context: str
|
||||
"""A unique string representing the context of the event."""
|
||||
|
||||
|
||||
class SlackAppMentionEventProfile(SlackEventProfile):
|
||||
r"""Represents the detailed profile of a Slack event where the app was
|
||||
mentioned in a message.
|
||||
"""
|
||||
|
||||
channel_type: Optional[str] = None
|
||||
"""The type of Slack channel. it's None for app mentions."""
|
||||
|
||||
|
||||
class SlackAppMentionEventBody(SlackEventBody):
|
||||
r"""Represents the entire body of a Slack event where the app was mentioned
|
||||
in a message.
|
||||
"""
|
||||
|
||||
context_team_id: Optional[str] = None
|
||||
"""A detailed profile of the event. it's None for app mentions."""
|
||||
|
||||
event: SlackAppMentionEventProfile
|
||||
"""A detailed profile of the event"""
|
||||
255
camel/bots/slack/slack_app.py
Normal file
255
camel/bots/slack/slack_app.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from slack_sdk.oauth.installation_store.async_installation_store import (
|
||||
AsyncInstallationStore,
|
||||
)
|
||||
from starlette import requests, responses
|
||||
|
||||
from camel.bots.slack.models import (
|
||||
SlackAppMentionEventBody,
|
||||
SlackAppMentionEventProfile,
|
||||
SlackEventBody,
|
||||
SlackEventProfile,
|
||||
)
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from slack_bolt.context.async_context import AsyncBoltContext
|
||||
from slack_bolt.context.say.async_say import AsyncSay
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackApp:
|
||||
r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.
|
||||
|
||||
This class is responsible for initializing and managing the Slack
|
||||
application by setting up event handlers, running the app server, and
|
||||
handling events such as messages and mentions from Slack.
|
||||
|
||||
Args:
|
||||
token (Optional[str]): Slack API token for authentication.
|
||||
scopes (Optional[str]): Slack app scopes for permissions.
|
||||
signing_secret (Optional[str]): Signing secret for verifying Slack
|
||||
requests.
|
||||
client_id (Optional[str]): Slack app client ID.
|
||||
client_secret (Optional[str]): Slack app client secret.
|
||||
redirect_uri_path (str): The URI path for OAuth redirect, defaults to
|
||||
"/slack/oauth_redirect".
|
||||
installation_store (Optional[AsyncInstallationStore]): The installation
|
||||
store for handling OAuth installations.
|
||||
"""
|
||||
|
||||
@dependencies_required('slack_bolt')
|
||||
def __init__(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
scopes: Optional[str] = None,
|
||||
signing_secret: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
client_secret: Optional[str] = None,
|
||||
redirect_uri_path: str = "/slack/oauth_redirect",
|
||||
installation_store: Optional[AsyncInstallationStore] = None,
|
||||
) -> None:
|
||||
r"""Initializes the SlackApp instance by setting up the Slack Bolt app
|
||||
and configuring event handlers and OAuth settings.
|
||||
|
||||
Args:
|
||||
token (Optional[str]): The Slack API token.
|
||||
scopes (Optional[str]): The scopes for Slack app permissions.
|
||||
signing_secret (Optional[str]): The signing secret for verifying
|
||||
requests.
|
||||
client_id (Optional[str]): The Slack app client ID.
|
||||
client_secret (Optional[str]): The Slack app client secret.
|
||||
redirect_uri_path (str): The URI path for handling OAuth redirects
|
||||
(default is "/slack/oauth_redirect").
|
||||
installation_store (Optional[AsyncInstallationStore]): An optional
|
||||
installation store for OAuth installations.
|
||||
"""
|
||||
from slack_bolt.adapter.starlette.async_handler import (
|
||||
AsyncSlackRequestHandler,
|
||||
)
|
||||
from slack_bolt.app.async_app import AsyncApp
|
||||
from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings
|
||||
|
||||
self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
|
||||
self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
|
||||
self.signing_secret: Optional[str] = signing_secret or os.getenv(
|
||||
"SLACK_SIGNING_SECRET"
|
||||
)
|
||||
self.client_id: Optional[str] = client_id or os.getenv(
|
||||
"SLACK_CLIENT_ID"
|
||||
)
|
||||
self.client_secret: Optional[str] = client_secret or os.getenv(
|
||||
"SLACK_CLIENT_SECRET"
|
||||
)
|
||||
|
||||
if not all([self.token, self.scopes, self.signing_secret]):
|
||||
raise ValueError(
|
||||
"`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
|
||||
"environment variables must be set. Get it here: "
|
||||
"`https://api.slack.com/apps`."
|
||||
)
|
||||
|
||||
# Setup OAuth settings if client ID and secret are provided
|
||||
if self.client_id and self.client_secret:
|
||||
self._app = AsyncApp(
|
||||
oauth_settings=AsyncOAuthSettings(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
scopes=self.scopes,
|
||||
redirect_uri_path=redirect_uri_path,
|
||||
),
|
||||
logger=logger,
|
||||
signing_secret=self.signing_secret,
|
||||
installation_store=installation_store,
|
||||
token=self.token,
|
||||
)
|
||||
else:
|
||||
# Initialize Slack Bolt AsyncApp with settings
|
||||
self._app = AsyncApp(
|
||||
logger=logger,
|
||||
signing_secret=self.signing_secret,
|
||||
installation_store=installation_store,
|
||||
token=self.token,
|
||||
)
|
||||
|
||||
self._handler = AsyncSlackRequestHandler(self._app)
|
||||
self.setup_handlers()
|
||||
|
||||
def setup_handlers(self) -> None:
|
||||
r"""Sets up the event handlers for Slack events, such as `app_mention`
|
||||
and `message`.
|
||||
|
||||
This method registers the `app_mention` and `on_message` event handlers
|
||||
with the Slack Bolt app to respond to Slack events.
|
||||
"""
|
||||
self._app.event("app_mention")(self.app_mention)
|
||||
self._app.event("message")(self.on_message)
|
||||
|
||||
def run(
|
||||
self,
|
||||
port: int = 3000,
|
||||
path: str = "/slack/events",
|
||||
host: Optional[str] = None,
|
||||
) -> None:
|
||||
r"""Starts the Slack Bolt app server to listen for incoming Slack
|
||||
events.
|
||||
|
||||
Args:
|
||||
port (int): The port on which the server should run (default is
|
||||
3000).
|
||||
path (str): The endpoint path for receiving Slack events (default
|
||||
is "/slack/events").
|
||||
host (Optional[str]): The hostname to bind the server (default is
|
||||
None).
|
||||
"""
|
||||
self._app.start(port=port, path=path, host=host)
|
||||
|
||||
async def handle_request(
|
||||
self, request: requests.Request
|
||||
) -> responses.Response:
|
||||
r"""Handles incoming requests from Slack through the request handler.
|
||||
|
||||
Args:
|
||||
request (Request): A Starlette request object representing the
|
||||
incoming request.
|
||||
|
||||
Returns:
|
||||
The response generated by the Slack Bolt handler.
|
||||
"""
|
||||
return await self._handler.handle(request)
|
||||
|
||||
async def app_mention(
|
||||
self,
|
||||
context: "AsyncBoltContext",
|
||||
client: "AsyncWebClient",
|
||||
event: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
say: "AsyncSay",
|
||||
) -> None:
|
||||
r"""Event handler for `app_mention` events.
|
||||
|
||||
This method is triggered when someone mentions the app in Slack.
|
||||
|
||||
Args:
|
||||
context (AsyncBoltContext): The Slack Bolt context for the event.
|
||||
client (AsyncWebClient): The Slack Web API client.
|
||||
event (Dict[str, Any]): The event data for the app mention.
|
||||
body (Dict[str, Any]): The full request body from Slack.
|
||||
say (AsyncSay): A function to send a response back to the channel.
|
||||
"""
|
||||
event_profile = SlackAppMentionEventProfile(**event)
|
||||
event_body = SlackAppMentionEventBody(**body)
|
||||
|
||||
logger.info(f"app_mention, context: {context}")
|
||||
logger.info(f"app_mention, client: {client}")
|
||||
logger.info(f"app_mention, event_profile: {event_profile}")
|
||||
logger.info(f"app_mention, event_body: {event_body}")
|
||||
logger.info(f"app_mention, say: {say}")
|
||||
|
||||
async def on_message(
|
||||
self,
|
||||
context: "AsyncBoltContext",
|
||||
client: "AsyncWebClient",
|
||||
event: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
say: "AsyncSay",
|
||||
) -> None:
|
||||
r"""Event handler for `message` events.
|
||||
|
||||
This method is triggered when the app receives a message in Slack.
|
||||
|
||||
Args:
|
||||
context (AsyncBoltContext): The Slack Bolt context for the event.
|
||||
client (AsyncWebClient): The Slack Web API client.
|
||||
event (Dict[str, Any]): The event data for the message.
|
||||
body (Dict[str, Any]): The full request body from Slack.
|
||||
say (AsyncSay): A function to send a response back to the channel.
|
||||
"""
|
||||
await context.ack()
|
||||
|
||||
event_profile = SlackEventProfile(**event)
|
||||
event_body = SlackEventBody(**body)
|
||||
|
||||
logger.info(f"on_message, context: {context}")
|
||||
logger.info(f"on_message, client: {client}")
|
||||
logger.info(f"on_message, event_profile: {event_profile}")
|
||||
logger.info(f"on_message, event_body: {event_body}")
|
||||
logger.info(f"on_message, say: {say}")
|
||||
|
||||
logger.info(f"Received message: {event_profile.text}")
|
||||
|
||||
def mention_me(
|
||||
self, context: "AsyncBoltContext", body: SlackEventBody
|
||||
) -> bool:
|
||||
r"""Check if the bot is mentioned in the message.
|
||||
|
||||
Args:
|
||||
context (AsyncBoltContext): The Slack Bolt context for the event.
|
||||
body (SlackEventBody): The body of the Slack event.
|
||||
|
||||
Returns:
|
||||
bool: True if the bot is mentioned in the message, False otherwise.
|
||||
"""
|
||||
message = body.event.text
|
||||
bot_user_id = context.bot_user_id
|
||||
mention = f"<@{bot_user_id}>"
|
||||
return mention in message
|
||||
78
camel/bots/telegram_bot.py
Normal file
78
camel/bots/telegram_bot.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.utils import dependencies_required
|
||||
|
||||
# Conditionally import telebot types only for type checking
|
||||
if TYPE_CHECKING:
|
||||
from telebot.types import ( # type: ignore[import-untyped]
|
||||
Message,
|
||||
)
|
||||
|
||||
|
||||
class TelegramBot:
|
||||
r"""Represents a Telegram bot that is powered by an agent.
|
||||
|
||||
Attributes:
|
||||
chat_agent (ChatAgent): Chat agent that will power the bot.
|
||||
telegram_token (str, optional): The bot token.
|
||||
"""
|
||||
|
||||
@dependencies_required('telebot')
|
||||
def __init__(
|
||||
self,
|
||||
chat_agent: ChatAgent,
|
||||
telegram_token: Optional[str] = None,
|
||||
) -> None:
|
||||
self.chat_agent = chat_agent
|
||||
|
||||
if not telegram_token:
|
||||
self.token = os.getenv('TELEGRAM_TOKEN')
|
||||
if not self.token:
|
||||
raise ValueError(
|
||||
"`TELEGRAM_TOKEN` not found in environment variables. "
|
||||
"Get it from t.me/BotFather."
|
||||
)
|
||||
else:
|
||||
self.token = telegram_token
|
||||
|
||||
import telebot # type: ignore[import-untyped]
|
||||
|
||||
self.bot = telebot.TeleBot(token=self.token)
|
||||
|
||||
# Register the message handler within the constructor
|
||||
self.bot.message_handler(func=lambda message: True)(self.on_message)
|
||||
|
||||
def run(self) -> None:
|
||||
r"""Start the Telegram bot."""
|
||||
print("Telegram bot is running...")
|
||||
self.bot.infinity_polling()
|
||||
|
||||
def on_message(self, message: 'Message') -> None:
|
||||
r"""Handles incoming messages from the user.
|
||||
|
||||
Args:
|
||||
message (types.Message): The incoming message object.
|
||||
"""
|
||||
self.chat_agent.reset()
|
||||
|
||||
if not message.text:
|
||||
return
|
||||
|
||||
assistant_response = self.chat_agent.step(message.text)
|
||||
|
||||
self.bot.reply_to(message, assistant_response.msg.content)
|
||||
106
camel/configs/__init__.py
Normal file
106
camel/configs/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .aiml_config import AIML_API_PARAMS, AIMLConfig
|
||||
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
|
||||
from .base_config import BaseConfig
|
||||
from .bedrock_config import BEDROCK_API_PARAMS, BedrockConfig
|
||||
from .cohere_config import COHERE_API_PARAMS, CohereConfig
|
||||
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
|
||||
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
|
||||
from .groq_config import GROQ_API_PARAMS, GroqConfig
|
||||
from .internlm_config import INTERNLM_API_PARAMS, InternLMConfig
|
||||
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
|
||||
from .lmstudio_config import LMSTUDIO_API_PARAMS, LMStudioConfig
|
||||
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
|
||||
from .modelscope_config import MODELSCOPE_API_PARAMS, ModelScopeConfig
|
||||
from .moonshot_config import MOONSHOT_API_PARAMS, MoonshotConfig
|
||||
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
|
||||
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
|
||||
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
|
||||
from .openrouter_config import OPENROUTER_API_PARAMS, OpenRouterConfig
|
||||
from .ppio_config import PPIO_API_PARAMS, PPIOConfig
|
||||
from .qwen_config import QWEN_API_PARAMS, QwenConfig
|
||||
from .reka_config import REKA_API_PARAMS, RekaConfig
|
||||
from .samba_config import (
|
||||
SAMBA_CLOUD_API_PARAMS,
|
||||
SAMBA_VERSE_API_PARAMS,
|
||||
SambaCloudAPIConfig,
|
||||
SambaVerseAPIConfig,
|
||||
)
|
||||
from .sglang_config import SGLANG_API_PARAMS, SGLangConfig
|
||||
from .siliconflow_config import SILICONFLOW_API_PARAMS, SiliconFlowConfig
|
||||
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
|
||||
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
|
||||
from .yi_config import YI_API_PARAMS, YiConfig
|
||||
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
|
||||
|
||||
__all__ = [
|
||||
'BaseConfig',
|
||||
'ChatGPTConfig',
|
||||
'OPENAI_API_PARAMS',
|
||||
'AnthropicConfig',
|
||||
'ANTHROPIC_API_PARAMS',
|
||||
'GROQ_API_PARAMS',
|
||||
'GroqConfig',
|
||||
'LiteLLMConfig',
|
||||
'LITELLM_API_PARAMS',
|
||||
'NvidiaConfig',
|
||||
'NVIDIA_API_PARAMS',
|
||||
'OllamaConfig',
|
||||
'OLLAMA_API_PARAMS',
|
||||
'ZhipuAIConfig',
|
||||
'ZHIPUAI_API_PARAMS',
|
||||
'GeminiConfig',
|
||||
'Gemini_API_PARAMS',
|
||||
'VLLMConfig',
|
||||
'VLLM_API_PARAMS',
|
||||
'SGLangConfig',
|
||||
'SGLANG_API_PARAMS',
|
||||
'MistralConfig',
|
||||
'MISTRAL_API_PARAMS',
|
||||
'RekaConfig',
|
||||
'REKA_API_PARAMS',
|
||||
'SambaVerseAPIConfig',
|
||||
'SAMBA_VERSE_API_PARAMS',
|
||||
'SambaCloudAPIConfig',
|
||||
'SAMBA_CLOUD_API_PARAMS',
|
||||
'TogetherAIConfig',
|
||||
'TOGETHERAI_API_PARAMS',
|
||||
'CohereConfig',
|
||||
'COHERE_API_PARAMS',
|
||||
'YiConfig',
|
||||
'YI_API_PARAMS',
|
||||
'QwenConfig',
|
||||
'QWEN_API_PARAMS',
|
||||
'BedrockConfig',
|
||||
'BEDROCK_API_PARAMS',
|
||||
'DeepSeekConfig',
|
||||
'DEEPSEEK_API_PARAMS',
|
||||
'PPIOConfig',
|
||||
'PPIO_API_PARAMS',
|
||||
'InternLMConfig',
|
||||
'INTERNLM_API_PARAMS',
|
||||
'MoonshotConfig',
|
||||
"MOONSHOT_API_PARAMS",
|
||||
'ModelScopeConfig',
|
||||
'MODELSCOPE_API_PARAMS',
|
||||
'SiliconFlowConfig',
|
||||
'SILICONFLOW_API_PARAMS',
|
||||
'AIMLConfig',
|
||||
'AIML_API_PARAMS',
|
||||
'OpenRouterConfig',
|
||||
'OPENROUTER_API_PARAMS',
|
||||
'LMSTUDIO_API_PARAMS',
|
||||
'LMStudioConfig',
|
||||
]
|
||||
81
camel/configs/aiml_config.py
Normal file
81
camel/configs/aiml_config.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NotGiven
|
||||
|
||||
|
||||
class AIMLConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
AIML API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Determines the degree of randomness
|
||||
in the response. (default: :obj:`None`)
|
||||
top_p (float, optional): The top_p (nucleus) parameter is used to
|
||||
dynamically adjust the number of choices for each predicted token
|
||||
based on the cumulative probabilities. (default: :obj:`None`)
|
||||
n (int, optional): Number of generations to return.
|
||||
(default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output.
|
||||
stream (bool, optional): If set, tokens are returned as Server-Sent
|
||||
Events as they are made available. (default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate.
|
||||
(default: :obj:`None`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str], NotGiven]] = None
|
||||
max_tokens: Optional[Union[int, NotGiven]] = None
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
response_format: Optional[Union[dict, NotGiven]] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
|
||||
|
||||
AIML_API_PARAMS = {param for param in AIMLConfig.model_fields.keys()}
|
||||
80
camel/configs/anthropic_config.py
Normal file
80
camel/configs/anthropic_config.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class AnthropicConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Anthropic API.
|
||||
|
||||
See: https://docs.anthropic.com/en/api/messages
|
||||
|
||||
Args:
|
||||
max_tokens (int, optional): The maximum number of tokens to
|
||||
generate before stopping. Note that Anthropic models may stop
|
||||
before reaching this maximum. This parameter only specifies the
|
||||
absolute maximum number of tokens to generate.
|
||||
(default: :obj:`None`)
|
||||
stop_sequences (List[str], optional): Custom text sequences that will
|
||||
cause the model to stop generating. The models will normally stop
|
||||
when they have naturally completed their turn. If the model
|
||||
encounters one of these custom sequences, the response will be
|
||||
terminated and the stop_reason will be "stop_sequence".
|
||||
(default: :obj:`None`)
|
||||
temperature (float, optional): Amount of randomness injected into the
|
||||
response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
|
||||
for analytical / multiple choice, and closer to 1 for creative
|
||||
and generative tasks. Note that even with temperature of 0.0, the
|
||||
results will not be fully deterministic. (default: :obj:`None`)
|
||||
top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
|
||||
compute the cumulative distribution over all the options for each
|
||||
subsequent token in decreasing probability order and cut it off
|
||||
once it reaches a particular probability specified by `top_p`.
|
||||
You should either alter `temperature` or `top_p`,
|
||||
but not both. (default: :obj:`None`)
|
||||
top_k (int, optional): Only sample from the top K options for each
|
||||
subsequent token. Used to remove "long tail" low probability
|
||||
responses. (default: :obj:`None`)
|
||||
stream (bool, optional): Whether to incrementally stream the response
|
||||
using server-sent events. (default: :obj:`None`)
|
||||
metadata (dict, optional): An object describing
|
||||
metadata about the request. Can include user_id as an external
|
||||
identifier for the user associated with the request.
|
||||
(default: :obj:`None`)
|
||||
thinking (dict, optional): Configuration for enabling
|
||||
Claude's extended thinking. When enabled, responses include
|
||||
thinking content blocks showing Claude's thinking process.
|
||||
(default: :obj:`None`)
|
||||
tool_choice (dict, optional): How the model should
|
||||
use the provided tools. The model can use a specific tool, any
|
||||
available tool, decide by itself, or not use tools at all.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
metadata: Optional[dict] = None
|
||||
thinking: Optional[dict] = None
|
||||
tool_choice: Optional[dict] = None
|
||||
|
||||
|
||||
ANTHROPIC_API_PARAMS = {param for param in AnthropicConfig.model_fields.keys()}
|
||||
86
camel/configs/base_config.py
Normal file
86
camel/configs/base_config.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
|
||||
class BaseConfig(ABC, BaseModel):
|
||||
r"""Base configuration class for all models.
|
||||
|
||||
This class provides a common interface for all models, ensuring that all
|
||||
models have a consistent set of attributes and methods.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
frozen=True,
|
||||
# UserWarning: conflict with protected namespace "model_"
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
tools: Optional[List[Any]] = None
|
||||
"""A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
"""
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def fields_type_checking(cls, tools):
|
||||
r"""Validate the type of tools in the configuration.
|
||||
|
||||
This method ensures that the tools provided in the configuration are
|
||||
instances of `FunctionTool`. If any tool is not an instance of
|
||||
`FunctionTool`, it raises a ValueError.
|
||||
"""
|
||||
if tools is not None:
|
||||
from camel.toolkits import FunctionTool
|
||||
|
||||
for tool in tools:
|
||||
if not isinstance(tool, FunctionTool):
|
||||
raise ValueError(
|
||||
f"The tool {tool} should "
|
||||
"be an instance of `FunctionTool`."
|
||||
)
|
||||
return tools
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
r"""Convert the current configuration to a dictionary.
|
||||
|
||||
This method converts the current configuration object to a dictionary
|
||||
representation, which can be used for serialization or other purposes.
|
||||
The dictionary won't contain None values, as some API does not support
|
||||
None values. (Like tool in OpenAI beta API)
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: A dictionary representation of the current
|
||||
configuration.
|
||||
"""
|
||||
config_dict = self.model_dump()
|
||||
|
||||
# Convert tools to OpenAI tool schema
|
||||
config_dict["tools"] = (
|
||||
[tool.get_openai_tool_schema() for tool in self.tools]
|
||||
if self.tools
|
||||
else None
|
||||
)
|
||||
|
||||
# Remove None values
|
||||
return {k: v for k, v in config_dict.items() if v is not None}
|
||||
73
camel/configs/bedrock_config.py
Normal file
73
camel/configs/bedrock_config.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class BedrockConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility.
|
||||
|
||||
Args:
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
top_k (int, optional): The number of top tokens to consider.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
reasoning_effort(str, optional): A parameter specifying the level of
|
||||
reasoning used by certain model types. Valid values are :obj:
|
||||
`"low"`, :obj:`"medium"`, or :obj:`"high"`. If set, it is only
|
||||
applied to the model types that support it (e.g., :obj:`o1`,
|
||||
:obj:`o1mini`, :obj:`o1preview`, :obj:`o3mini`). If not provided
|
||||
or if the model type does not support it, this parameter is
|
||||
ignored. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
tool_choice: Optional[Union[Dict[str, str], str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
|
||||
|
||||
BEDROCK_API_PARAMS = {param for param in BedrockConfig.model_fields.keys()}
|
||||
77
camel/configs/cohere_config.py
Normal file
77
camel/configs/cohere_config.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class CohereConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Cohere API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
documents (list, optional): A list of relevant documents that the
|
||||
model can cite to generate a more accurate reply. Each document is
|
||||
either a string or document object with content and metadata.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens the model
|
||||
will generate as part of the response. (default: :obj:`None`)
|
||||
stop_sequences (List(str), optional): A list of up to 5 strings that
|
||||
the model will use to stop generation. If the model generates a
|
||||
string that matches any of the strings in the list, it will stop
|
||||
generating tokens and return the generated text up to that point
|
||||
not including the stop sequence. (default: :obj:`None`)
|
||||
seed (int, optional): If specified, the backend will make a best
|
||||
effort to sample tokens deterministically, such that repeated
|
||||
requests with the same seed and parameters should return the same
|
||||
result. However, determinism cannot be totally guaranteed.
|
||||
(default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Min value of `0.0`, max value of
|
||||
`1.0`. Used to reduce repetitiveness of generated tokens. The
|
||||
higher the value, the stronger a penalty is applied to previously
|
||||
present tokens, proportional to how many times they have already
|
||||
appeared in the prompt or prior generation.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Min value of `0.0`, max value of
|
||||
`1.0`. Used to reduce repetitiveness of generated tokens. Similar
|
||||
to `frequency_penalty`, except that this penalty is applied
|
||||
equally to all tokens that have already appeared, regardless of
|
||||
their exact frequencies. (default: :obj:`None`)
|
||||
k (int, optional): Ensures only the top k most likely tokens are
|
||||
considered for generation at each step. Min value of `0`, max
|
||||
value of `500`. (default: :obj:`None`)
|
||||
p (float, optional): Ensures that only the most likely tokens, with
|
||||
total probability mass of `p`, are considered for generation at
|
||||
each step. If both k and p are enabled, `p` acts after `k`. Min
|
||||
value of `0.01`, max value of `0.99`. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
documents: Optional[list] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
seed: Optional[int] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[float] = None
|
||||
|
||||
|
||||
COHERE_API_PARAMS = {param for param in CohereConfig().model_fields.keys()}
|
||||
108
camel/configs/deepseek_config.py
Normal file
108
camel/configs/deepseek_config.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class DeepSeekConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
DeepSeek API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): Controls the diversity and focus of the
|
||||
generated results. Higher values make the output more diverse,
|
||||
while lower values make it more focused. (default: :obj:`None`)
|
||||
response_format (object, optional): Specifies the format of the
|
||||
returned content. The available values are `{"type": "text"}` or
|
||||
`{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
|
||||
will output a standard JSON string.
|
||||
(default: :obj:`None`)
|
||||
stream (bool, optional): If set, partial message deltas will be sent.
|
||||
Tokens will be sent as data-only server-sent events (SSE) as
|
||||
they become available, with the stream terminated by a
|
||||
data: [DONE] message. (default: :obj:`None`)
|
||||
stop (Union[str, list[str]], optional): Up to 16 sequences where
|
||||
the API will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens that can
|
||||
be generated in the chat completion. The total length of input
|
||||
tokens and generated tokens is limited by the model's context
|
||||
length. (default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between -2.0 and 2.0.
|
||||
Positive values penalize new tokens based on whether they
|
||||
appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between -2.0 and 2.0.
|
||||
Positive values penalize new tokens based on their existing
|
||||
frequency in the text so far, decreasing the model's likelihood
|
||||
to repeat the same line verbatim. (default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use
|
||||
this to provide a list of functions the model may generate JSON
|
||||
inputs for. A max of 128 functions are supported.
|
||||
(default: :obj:`None`)
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which
|
||||
(if any) tool is called by the model. "none" means the model
|
||||
will not call any tool and instead generates a message. "auto"
|
||||
means the model can pick between generating a message or calling
|
||||
one or more tools. "required" means the model must call one or
|
||||
more tools. Specifying a particular tool via
|
||||
{"type": "function", "function": {"name": "my_function"}} forces
|
||||
the model to call that tool. "none" is the default when no tools
|
||||
are present. "auto" is the default if tools are present.
|
||||
(default: :obj:`None`)
|
||||
logprobs (bool, optional): Whether to return log probabilities of
|
||||
the output tokens or not. If true, returns the log probabilities
|
||||
of each output token returned in the content of message.
|
||||
(default: :obj:`None`)
|
||||
top_logprobs (int, optional): An integer between 0 and 20 specifying
|
||||
the number of most likely tokens to return at each token
|
||||
position, each with an associated log probability. logprobs
|
||||
must be set to true if this parameter is used.
|
||||
(default: :obj:`None`)
|
||||
include_usage (bool, optional): When streaming, specifies whether to
|
||||
include usage information in `stream_options`.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # deepseek default: 1.0
|
||||
top_p: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Union[Type[BaseModel], dict]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
|
||||
def __init__(self, include_usage: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Only set stream_options when stream is True
|
||||
# Otherwise, it will raise error when calling the API
|
||||
if self.stream:
|
||||
self.stream_options = {"include_usage": include_usage}
|
||||
|
||||
|
||||
DEEPSEEK_API_PARAMS = {param for param in DeepSeekConfig.model_fields.keys()}
|
||||
88
camel/configs/gemini_config.py
Normal file
88
camel/configs/gemini_config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class GeminiConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Gemini API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # openai default: 1.0
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
response_format: Optional[Union[Type[BaseModel], dict]] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
Gemini_API_PARAMS = {param for param in GeminiConfig.model_fields.keys()}
|
||||
103
camel/configs/groq_config.py
Normal file
103
camel/configs/groq_config.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class GroqConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility.
|
||||
|
||||
Reference: https://console.groq.com/docs/openai
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
user: Optional[str] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
GROQ_API_PARAMS = {param for param in GroqConfig.model_fields.keys()}
|
||||
60
camel/configs/internlm_config.py
Normal file
60
camel/configs/internlm_config.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class InternLMConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
InternLM API. You can refer to the following link for more details:
|
||||
https://internlm.intern-ai.org.cn/api/document
|
||||
|
||||
Args:
|
||||
stream (bool, optional): Whether to stream the response.
|
||||
(default: :obj:`None`)
|
||||
temperature (float, optional): Controls the diversity and focus of
|
||||
the generated results. Lower values make the output more focused,
|
||||
while higher values make it more diverse. (default: :obj:`None`)
|
||||
top_p (float, optional): Controls the diversity and focus of the
|
||||
generated results. Higher values make the output more diverse,
|
||||
while lower values make it more focused. (default: :obj:`None`)
|
||||
max_tokens (int, optional): Allows the model to
|
||||
generate the maximum number of tokens.
|
||||
(default: :obj:`None`)
|
||||
tools (list, optional): Specifies an array of tools that the model can
|
||||
call. It can contain one or more tool objects. During a function
|
||||
call process, the model will select one tool from the array.
|
||||
(default: :obj:`None`)
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
stream: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
INTERNLM_API_PARAMS = {param for param in InternLMConfig.model_fields.keys()}
|
||||
99
camel/configs/litellm_config.py
Normal file
99
camel/configs/litellm_config.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class LiteLLMConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
LiteLLM API.
|
||||
|
||||
Args:
|
||||
timeout (Optional[Union[float, str]], optional): Request timeout.
|
||||
(default: :obj:`None`)
|
||||
temperature (Optional[float], optional): Temperature parameter for
|
||||
controlling randomness. (default: :obj:`None`)
|
||||
top_p (Optional[float], optional): Top-p parameter for nucleus
|
||||
sampling. (default: :obj:`None`)
|
||||
n (Optional[int], optional): Number of completions to generate.
|
||||
(default: :obj:`None`)
|
||||
stream (Optional[bool], optional): Whether to return a streaming
|
||||
response. (default: :obj:`None`)
|
||||
stream_options (Optional[dict], optional): Options for the streaming
|
||||
response. (default: :obj:`None`)
|
||||
stop (Optional[Union[str, List[str]]], optional): Sequences where the
|
||||
API will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (Optional[int], optional): Maximum number of tokens to
|
||||
generate. (default: :obj:`None`)
|
||||
presence_penalty (Optional[float], optional): Penalize new tokens
|
||||
based on their existence in the text so far. (default: :obj:`None`)
|
||||
frequency_penalty (Optional[float], optional): Penalize new tokens
|
||||
based on their frequency in the text so far. (default: :obj:`None`)
|
||||
logit_bias (Optional[dict], optional): Modify the probability of
|
||||
specific tokens appearing in the completion. (default: :obj:`None`)
|
||||
user (Optional[str], optional): A unique identifier representing the
|
||||
end-user. (default: :obj:`None`)
|
||||
response_format (Optional[dict], optional): Response format
|
||||
parameters. (default: :obj:`None`)
|
||||
seed (Optional[int], optional): Random seed. (default: :obj:`None`)
|
||||
tools (Optional[List], optional): List of tools. (default: :obj:`None`)
|
||||
tool_choice (Optional[Union[str, dict]], optional): Tool choice
|
||||
parameters. (default: :obj:`None`)
|
||||
logprobs (Optional[bool], optional): Whether to return log
|
||||
probabilities of the output tokens. (default: :obj:`None`)
|
||||
top_logprobs (Optional[int], optional): Number of most likely tokens
|
||||
to return at each token position. (default: :obj:`None`)
|
||||
deployment_id (Optional[str], optional): Deployment ID.
|
||||
(default: :obj:`None`)
|
||||
extra_headers (Optional[dict], optional): Additional headers for the
|
||||
request. (default: :obj:`None`)
|
||||
api_version (Optional[str], optional): API version.
|
||||
(default: :obj:`None`)
|
||||
mock_response (Optional[str], optional): Mock completion response for
|
||||
testing or debugging. (default: :obj:`None`)
|
||||
custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
|
||||
provider. (default: :obj:`None`)
|
||||
max_retries (Optional[int], optional): Maximum number of retries.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
timeout: Optional[Union[float, str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stream_options: Optional[dict] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
user: Optional[str] = None
|
||||
response_format: Optional[dict] = None
|
||||
seed: Optional[int] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
deployment_id: Optional[str] = None
|
||||
extra_headers: Optional[dict] = None
|
||||
api_version: Optional[str] = None
|
||||
mock_response: Optional[str] = None
|
||||
custom_llm_provider: Optional[str] = None
|
||||
max_retries: Optional[int] = None
|
||||
|
||||
|
||||
LITELLM_API_PARAMS = {param for param in LiteLLMConfig.model_fields.keys()}
|
||||
94
camel/configs/lmstudio_config.py
Normal file
94
camel/configs/lmstudio_config.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class LMStudioConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
LMSTUDIO_API_PARAMS = {param for param in LMStudioConfig.model_fields.keys()}
|
||||
79
camel/configs/mistral_config.py
Normal file
79
camel/configs/mistral_config.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class MistralConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Mistral API.
|
||||
|
||||
reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195
|
||||
|
||||
#TODO: Support stream mode
|
||||
|
||||
Args:
|
||||
temperature (Optional[float], optional): temperature the temperature
|
||||
to use for sampling, e.g. 0.5. (default: :obj:`None`)
|
||||
top_p (Optional[float], optional): the cumulative probability of
|
||||
tokens to generate, e.g. 0.9. (default: :obj:`None`)
|
||||
max_tokens (Optional[int], optional): the maximum number of tokens to
|
||||
generate, e.g. 100. (default: :obj:`None`)
|
||||
stop (Optional[Union[str,list[str]]]): Stop generation if this token
|
||||
is detected. Or if one of these tokens is detected when providing
|
||||
a string list. (default: :obj:`None`)
|
||||
random_seed (Optional[int], optional): the random seed to use for
|
||||
sampling, e.g. 42. (default: :obj:`None`)
|
||||
safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
|
||||
(default: :obj:`None`)
|
||||
response_format (Union[Dict[str, str], ResponseFormat): format of the
|
||||
response.
|
||||
tool_choice (str, optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"any"` means the
|
||||
model must call one or more tools. :obj:`"auto"` is the default
|
||||
value.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
random_seed: Optional[int] = None
|
||||
safe_prompt: Optional[bool] = None
|
||||
response_format: Optional[Union[Dict[str, str], Any]] = None
|
||||
tool_choice: Optional[str] = None
|
||||
|
||||
@field_validator("response_format", mode="before")
|
||||
@classmethod
|
||||
def fields_type_checking(cls, response_format):
|
||||
if response_format and not isinstance(response_format, dict):
|
||||
from mistralai.models import ResponseFormat
|
||||
|
||||
if not isinstance(response_format, ResponseFormat):
|
||||
raise ValueError(
|
||||
f"The tool {response_format} should be an instance "
|
||||
"of `mistralai.models.ResponseFormat`."
|
||||
)
|
||||
return response_format
|
||||
|
||||
|
||||
MISTRAL_API_PARAMS = {param for param in MistralConfig().model_fields.keys()}
|
||||
59
camel/configs/modelscope_config.py
Normal file
59
camel/configs/modelscope_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class ModelScopeConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
ModelScope API. You can refer to the following link for more details:
|
||||
https://www.modelscope.cn/docs/model-service/API-Inference/intro
|
||||
|
||||
Args:
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` or
|
||||
specifying a particular tool via
|
||||
{"type": "function", "function": {"name": "some_function"}}
|
||||
can be used to guide the model to use tools more strongly.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (int, optional): Specifies the maximum number of tokens
|
||||
the model can generate. This sets an upper limit, but does not
|
||||
guarantee that this number will always be reached.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): Controls the randomness of the generated
|
||||
results. Lower values lead to less randomness, while higher
|
||||
values increase randomness. (default: :obj:`None`)
|
||||
temperature (float, optional): Controls the diversity and focus of
|
||||
the generated results. Lower values make the output more focused,
|
||||
while higher values make it more diverse. (default: :obj:`0.3`)
|
||||
stream (bool, optional): If True, enables streaming output.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
|
||||
|
||||
MODELSCOPE_API_PARAMS = {
|
||||
param for param in ModelScopeConfig.model_fields.keys()
|
||||
}
|
||||
63
camel/configs/moonshot_config.py
Normal file
63
camel/configs/moonshot_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class MoonshotConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Moonshot API. You can refer to the following link for more details:
|
||||
https://platform.moonshot.cn/docs/api-reference
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Controls randomness in the response.
|
||||
Lower values make the output more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate.
|
||||
(default: :obj:`None`)
|
||||
stream (bool, optional): Whether to stream the response.
|
||||
(default: :obj:`False`)
|
||||
tools (list, optional): List of tools that the model can use for
|
||||
function calling. Each tool should be a dictionary containing
|
||||
type, function name, description, and parameters.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): Controls diversity via nucleus sampling.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message.(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Penalty for new tokens based on
|
||||
whether they appear in the text so far.
|
||||
(default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Penalty for new tokens based on
|
||||
their frequency in the text so far.
|
||||
(default: :obj:`None`)
|
||||
stop (Optional[Union[str, List[str]]], optional): Up to 4 sequences
|
||||
where the API will stop generating further tokens.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
tools: Optional[list] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
MOONSHOT_API_PARAMS = {param for param in MoonshotConfig.model_fields.keys()}
|
||||
70
camel/configs/nvidia_config.py
Normal file
70
camel/configs/nvidia_config.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NotGiven
|
||||
|
||||
|
||||
class NvidiaConfig(BaseConfig):
|
||||
r"""Configuration class for NVIDIA API models.
|
||||
|
||||
This class defines the configuration parameters for NVIDIA's language
|
||||
models, including temperature, sampling parameters, and response format
|
||||
settings.
|
||||
|
||||
Args:
|
||||
stream (bool, optional): Whether to stream the response.
|
||||
(default: :obj:`None`)
|
||||
temperature (float, optional): Controls randomness in the response.
|
||||
Higher values make output more random, lower values make it more
|
||||
deterministic. Range: [0.0, 2.0]. (default: :obj:`None`)
|
||||
top_p (float, optional): Controls diversity via nucleus sampling.
|
||||
Range: [0.0, 1.0]. (default: :obj:`None`)
|
||||
presence_penalty (float, optional): Penalizes new tokens based on
|
||||
whether they appear in the text so far. Range: [-2.0, 2.0].
|
||||
(default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Penalizes new tokens based on
|
||||
their frequency in the text so far. Range: [-2.0, 2.0].
|
||||
(default: :obj:`None`)
|
||||
max_tokens (Union[int, NotGiven], optional): Maximum number of tokens
|
||||
to generate. If not provided, model will use its default maximum.
|
||||
(default: :obj:`None`)
|
||||
seed (Optional[int], optional): Random seed for deterministic sampling.
|
||||
(default: :obj:`None`)
|
||||
tools (Optional[List[Dict]], optional): List of tools available to the
|
||||
model. This includes tools such as a text editor, a calculator, or
|
||||
a search engine. (default: :obj:`None`)
|
||||
tool_choice (Optional[str], optional): Tool choice configuration.
|
||||
(default: :obj:`None`)
|
||||
stop (Optional[List[str]], optional): List of stop sequences.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
stream: Optional[bool] = Field(default=None)
|
||||
temperature: Optional[float] = Field(default=None)
|
||||
top_p: Optional[float] = Field(default=None)
|
||||
presence_penalty: Optional[float] = Field(default=None)
|
||||
frequency_penalty: Optional[float] = Field(default=None)
|
||||
max_tokens: Optional[Union[int, NotGiven]] = Field(default=None)
|
||||
seed: Optional[int] = Field(default=None)
|
||||
tool_choice: Optional[str] = Field(default=None)
|
||||
stop: Optional[List[str]] = Field(default=None)
|
||||
|
||||
|
||||
NVIDIA_API_PARAMS = {param for param in NvidiaConfig.model_fields.keys()}
|
||||
83
camel/configs/ollama_config.py
Normal file
83
camel/configs/ollama_config.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class OllamaConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility
|
||||
|
||||
Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Union[Type[BaseModel], dict]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
|
||||
|
||||
OLLAMA_API_PARAMS = {param for param in OllamaConfig.model_fields.keys()}
|
||||
125
camel/configs/openai_config.py
Normal file
125
camel/configs/openai_config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class ChatGPTConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`None`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
reasoning_effort(str, optional): A parameter specifying the level of
|
||||
reasoning used by certain model types. Valid values are :obj:
|
||||
`"low"`, :obj:`"medium"`, or :obj:`"high"`. If set, it is only
|
||||
applied to the model types that support it (e.g., :obj:`o1`,
|
||||
:obj:`o1mini`, :obj:`o1preview`, :obj:`o3mini`). If not provided
|
||||
or if the model type does not support it, this parameter is
|
||||
ignored. (default: :obj:`None`)
|
||||
parallel_tool_calls (bool, optional): A parameter specifying whether
|
||||
the model should call tools in parallel or not.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Union[Type[BaseModel], Dict]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict] = None
|
||||
user: Optional[str] = None
|
||||
tool_choice: Optional[Union[Dict[str, str], str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
parallel_tool_calls: Optional[bool] = None
|
||||
|
||||
|
||||
OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
|
||||
106
camel/configs/openrouter_config.py
Normal file
106
camel/configs/openrouter_config.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NotGiven
|
||||
|
||||
|
||||
class OpenRouterConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility.
|
||||
|
||||
Reference: https://openrouter.ai/docs/api-reference/parameters
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported. (default: :obj:`None`)
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str], NotGiven]] = None
|
||||
max_tokens: Optional[Union[int, NotGiven]] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Union[dict, NotGiven]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
user: Optional[str] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
OPENROUTER_API_PARAMS = {
|
||||
param for param in OpenRouterConfig.model_fields.keys()
|
||||
}
|
||||
102
camel/configs/ppio_config.py
Normal file
102
camel/configs/ppio_config.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class PPIOConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`None`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Union[Type[BaseModel], Dict]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
PPIO_API_PARAMS = {param for param in PPIOConfig.model_fields.keys()}
|
||||
91
camel/configs/qwen_config.py
Normal file
91
camel/configs/qwen_config.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class QwenConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Qwen API. You can refer to the following link for more details:
|
||||
https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
|
||||
|
||||
Args:
|
||||
stream (bool, optional): Whether to stream the response.
|
||||
(default: :obj:`None`)
|
||||
temperature (float, optional): Controls the diversity and
|
||||
focus of the generated results. Lower values make the output more
|
||||
focused, while higher values make it more diverse.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): Controls the diversity and focus of
|
||||
the generated results. Higher values make the output more diverse,
|
||||
while lower values make it more focused. (default: :obj:`0.9`)
|
||||
presence_penalty (float, optional): Controls the repetition
|
||||
content in the generated results. Positive values reduce the
|
||||
repetition of content, while negative values increase it.
|
||||
(default: :obj:`None`)
|
||||
response_format (Optional[Dict[str, str]], optional): Specifies the
|
||||
format of the returned content. The available values are
|
||||
`{"type": "text"}` or `{"type": "json_object"}`. Setting it to
|
||||
`{"type": "json_object"}` will output a standard JSON string.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (Optional[int], optional): Allows the model to
|
||||
generate the maximum number of tokens.
|
||||
(default: :obj:`None`)
|
||||
seed (Optional[int], optional): Sets the seed parameter to make the
|
||||
text generation process more deterministic, typically used to
|
||||
ensure that the results are consistent across model runs. By
|
||||
passing the same seed value (specified by you) in each model call
|
||||
while keeping other parameters unchanged, the model is likely to
|
||||
return the same result.
|
||||
(default: :obj:`None`)
|
||||
stop (Optional[Union[str, List]], optional): Using the stop parameter,
|
||||
the model will automatically stop generating text when it is about
|
||||
to include the specified string or token_id. You can use the stop
|
||||
parameter to control the output of the model by passing sensitive
|
||||
words. (default: :obj:`None`)
|
||||
tools (List, optional): Specifies an array of tools that the model can
|
||||
call. It can contain one or more tool objects. During a function
|
||||
call process, the model will select one tool from the array.
|
||||
(default: :obj:`None`)
|
||||
extra_body (Optional[Dict[str, Any]], optional): Additional parameters
|
||||
to be sent to the Qwen API. If you want to enable internet search,
|
||||
you can set this parameter to `{"enable_search": True}`.
|
||||
(default: :obj:`None`)
|
||||
include_usage (bool, optional): When streaming, specifies whether to
|
||||
include usage information in `stream_options`.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
stream: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[Dict[str, str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List]] = None
|
||||
extra_body: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __init__(self, include_usage: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Only set stream_options when stream is True
|
||||
# Otherwise, it will raise error when calling the API
|
||||
if self.stream:
|
||||
self.stream_options = {"include_usage": include_usage}
|
||||
|
||||
|
||||
QWEN_API_PARAMS = {param for param in QwenConfig.model_fields.keys()}
|
||||
69
camel/configs/reka_config.py
Normal file
69
camel/configs/reka_config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class RekaConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Reka API.
|
||||
|
||||
Reference: https://docs.reka.ai/api-reference/chat/create
|
||||
|
||||
Args:
|
||||
temperature (Optional[float], optional): temperature the temperature
|
||||
to use for sampling, e.g. 0.5. (default: :obj:`None`)
|
||||
top_p (Optional[float], optional): the cumulative probability of
|
||||
tokens to generate, e.g. 0.9. (default: :obj:`None`)
|
||||
top_k (Optional[int], optional): Parameter which forces the model to
|
||||
only consider the tokens with the `top_k` highest probabilities at
|
||||
the next step. (default: :obj:`None`)
|
||||
max_tokens (Optional[int], optional): the maximum number of tokens to
|
||||
generate, e.g. 100. (default: :obj:`None`)
|
||||
stop (Optional[Union[str,list[str]]]): Stop generation if this token
|
||||
is detected. Or if one of these tokens is detected when providing
|
||||
a string list. (default: :obj:`None`)
|
||||
seed (Optional[int], optional): the random seed to use for sampling, e.
|
||||
g. 42. (default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
use_search_engine (Optional[bool]): Whether to consider using search
|
||||
engine to complete the request. Note that even if this is set to
|
||||
`True`, the model might decide to not use search.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
seed: Optional[int] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
use_search_engine: Optional[bool] = None
|
||||
|
||||
|
||||
REKA_API_PARAMS = {param for param in RekaConfig().model_fields.keys()}
|
||||
164
camel/configs/samba_config.py
Normal file
164
camel/configs/samba_config.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN, NotGiven
|
||||
|
||||
|
||||
class SambaVerseAPIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
SambaVerse API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
top_k (int, optional): Only sample from the top K options for each
|
||||
subsequent token. Used to remove "long tail" low probability
|
||||
responses.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (Optional[int], optional): The maximum number of tokens to
|
||||
generate, e.g. 100.
|
||||
(default: :obj:`None`)
|
||||
repetition_penalty (Optional[float], optional): The parameter for
|
||||
repetition penalty. 1.0 means no penalty.
|
||||
(default: :obj:`None`)
|
||||
stop (Optional[Union[str,list[str]]]): Stop generation if this token
|
||||
is detected. Or if one of these tokens is detected when providing
|
||||
a string list.
|
||||
(default: :obj:`None`)
|
||||
stream (Optional[bool]): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
Currently SambaVerse API doesn't support stream mode.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
stream: Optional[bool] = None
|
||||
|
||||
|
||||
SAMBA_VERSE_API_PARAMS = {
|
||||
param for param in SambaVerseAPIConfig().model_fields.keys()
|
||||
}
|
||||
|
||||
|
||||
class SambaCloudAPIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`0.2`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`1.0`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`1`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`False`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`0.0`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`0.0`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`""`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: float = 0.2 # openai default: 1.0
|
||||
top_p: float = 1.0
|
||||
n: int = 1
|
||||
stream: bool = False
|
||||
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
|
||||
max_tokens: Union[int, NotGiven] = NOT_GIVEN
|
||||
presence_penalty: float = 0.0
|
||||
response_format: Union[dict, NotGiven] = NOT_GIVEN
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
user: str = ""
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
SAMBA_CLOUD_API_PARAMS = {
|
||||
param for param in SambaCloudAPIConfig().model_fields.keys()
|
||||
}
|
||||
76
camel/configs/sglang_config.py
Normal file
76
camel/configs/sglang_config.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class SGLangConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Reference: https://sgl-project.github.io/references/sampling_params.html
|
||||
|
||||
Args:
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
stream (bool, optional): Whether to stream the generated output in
|
||||
chunks. If set to `True`, the response will be streamed as it is
|
||||
generated. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
tools (list[Dict[str, Any]], optional): A list of tool definitions
|
||||
that the model can dynamically invoke. Each tool should be
|
||||
defined as a dictionary following OpenAI's function calling
|
||||
specification format. For more details, refer to the OpenAI
|
||||
documentation. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
max_tokens: Optional[int] = None
|
||||
tools: Optional[Union[List[Dict[str, Any]]]] = None
|
||||
|
||||
|
||||
SGLANG_API_PARAMS = {param for param in SGLangConfig.model_fields.keys()}
|
||||
92
camel/configs/siliconflow_config.py
Normal file
92
camel/configs/siliconflow_config.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
from camel.types import NOT_GIVEN
|
||||
|
||||
|
||||
class SiliconFlowConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
SiliconFlow API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Determines the degree of randomness
|
||||
in the response. (default: :obj:`None`)
|
||||
top_p (float, optional): The top_p (nucleus) parameter is used to
|
||||
dynamically adjust the number of choices for each predicted token
|
||||
based on the cumulative probabilities. (default: :obj:`None`)
|
||||
n (int, optional): Number of generations to return.
|
||||
(default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. (default: :obj:`None`)
|
||||
stream (bool, optional): If set, tokens are returned as Server-Sent
|
||||
Events as they are made available. (default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate.
|
||||
(default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
response_format: Optional[Union[Type[BaseModel], dict]] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
r"""Convert the current configuration to a dictionary.
|
||||
|
||||
This method converts the current configuration object to a dictionary
|
||||
representation, which can be used for serialization or other purposes.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: A dictionary representation of the current
|
||||
configuration.
|
||||
"""
|
||||
config_dict = self.model_dump()
|
||||
if self.tools:
|
||||
from camel.toolkits import FunctionTool
|
||||
|
||||
tools_schema = []
|
||||
for tool in self.tools:
|
||||
if not isinstance(tool, FunctionTool):
|
||||
raise ValueError(
|
||||
f"The tool {tool} should "
|
||||
"be an instance of `FunctionTool`."
|
||||
)
|
||||
tools_schema.append(tool.get_openai_tool_schema())
|
||||
config_dict["tools"] = NOT_GIVEN
|
||||
return config_dict
|
||||
|
||||
|
||||
SILICONFLOW_API_PARAMS = {
|
||||
param for param in SiliconFlowConfig.model_fields.keys()
|
||||
}
|
||||
100
camel/configs/togetherai_config.py
Normal file
100
camel/configs/togetherai_config.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class TogetherAIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`{}`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # openai default: 1.0
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
TOGETHERAI_API_PARAMS = {
|
||||
param for param in TogetherAIConfig.model_fields.keys()
|
||||
}
|
||||
110
camel/configs/vllm_config.py
Normal file
110
camel/configs/vllm_config.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
# flake8: noqa: E501
|
||||
class VLLMConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
OpenAI API.
|
||||
|
||||
Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
n (int, optional): How many chat completion choices to generate for
|
||||
each input message. (default: :obj:`None`)
|
||||
response_format (object, optional): An object specifying the format
|
||||
that the model must output. Compatible with GPT-4 Turbo and all
|
||||
GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
|
||||
{"type": "json_object"} enables JSON mode, which guarantees the
|
||||
message the model generates is valid JSON. Important: when using
|
||||
JSON mode, you must also instruct the model to produce JSON
|
||||
yourself via a system or user message. Without this, the model
|
||||
may generate an unending stream of whitespace until the generation
|
||||
reaches the token limit, resulting in a long-running and seemingly
|
||||
"stuck" request. Also note that the message content may be
|
||||
partially cut off if finish_reason="length", which indicates the
|
||||
generation exceeded max_tokens or the conversation exceeded the
|
||||
max context length.
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
presence_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on whether
|
||||
they appear in the text so far, increasing the model's likelihood
|
||||
to talk about new topics. See more information about frequency and
|
||||
presence penalties. (default: :obj:`None`)
|
||||
frequency_penalty (float, optional): Number between :obj:`-2.0` and
|
||||
:obj:`2.0`. Positive values penalize new tokens based on their
|
||||
existing frequency in the text so far, decreasing the model's
|
||||
likelihood to repeat the same line verbatim. See more information
|
||||
about frequency and presence penalties. (default: :obj:`None`)
|
||||
logit_bias (dict, optional): Modify the likelihood of specified tokens
|
||||
appearing in the completion. Accepts a json object that maps tokens
|
||||
(specified by their token ID in the tokenizer) to an associated
|
||||
bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
|
||||
is added to the logits generated by the model prior to sampling.
|
||||
The exact effect will vary per model, but values between:obj:` -1`
|
||||
and :obj:`1` should decrease or increase likelihood of selection;
|
||||
values like :obj:`-100` or :obj:`100` should result in a ban or
|
||||
exclusive selection of the relevant token. (default: :obj:`None`)
|
||||
user (str, optional): A unique identifier representing your end-user,
|
||||
which can help OpenAI to monitor and detect abuse.
|
||||
(default: :obj:`None`)
|
||||
logprobs: Whether to return log probabilities of the output tokens or
|
||||
not. If true, returns the log probabilities of each output token
|
||||
returned in the `logits` of `message`. (default: :obj:`None`)
|
||||
top_logprobs: An integer between 0 and 20 specifying the number of
|
||||
most likely tokens to return at each token position, each with an
|
||||
associated log probability. `logprobs` must be set to `true` if
|
||||
this parameter is used. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # openai default: 1.0
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
response_format: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: dict = Field(default_factory=dict)
|
||||
user: Optional[str] = None
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
|
||||
|
||||
VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}
|
||||
57
camel/configs/yi_config.py
Normal file
57
camel/configs/yi_config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class YiConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using the
|
||||
Yi API. You can refer to the following link for more details:
|
||||
https://platform.lingyiwanwu.com/docs/api-reference
|
||||
|
||||
Args:
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` or
|
||||
specifying a particular tool via
|
||||
{"type": "function", "function": {"name": "some_function"}}
|
||||
can be used to guide the model to use tools more strongly.
|
||||
(default: :obj:`None`)
|
||||
max_tokens (int, optional): Specifies the maximum number of tokens
|
||||
the model can generate. This sets an upper limit, but does not
|
||||
guarantee that this number will always be reached.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): Controls the randomness of the generated
|
||||
results. Lower values lead to less randomness, while higher
|
||||
values increase randomness. (default: :obj:`None`)
|
||||
temperature (float, optional): Controls the diversity and focus of
|
||||
the generated results. Lower values make the output more focused,
|
||||
while higher values make it more diverse. (default: :obj:`0.3`)
|
||||
stream (bool, optional): If True, enables streaming output.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
|
||||
|
||||
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}
|
||||
70
camel/configs/zhipuai_config.py
Normal file
70
camel/configs/zhipuai_config.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from camel.configs.base_config import BaseConfig
|
||||
|
||||
|
||||
class ZhipuAIConfig(BaseConfig):
|
||||
r"""Defines the parameters for generating chat completions using OpenAI
|
||||
compatibility
|
||||
|
||||
Reference: https://open.bigmodel.cn/dev/api#glm-4v
|
||||
|
||||
Args:
|
||||
temperature (float, optional): Sampling temperature to use, between
|
||||
:obj:`0` and :obj:`2`. Higher values make the output more random,
|
||||
while lower values make it more focused and deterministic.
|
||||
(default: :obj:`None`)
|
||||
top_p (float, optional): An alternative to sampling with temperature,
|
||||
called nucleus sampling, where the model considers the results of
|
||||
the tokens with top_p probability mass. So :obj:`0.1` means only
|
||||
the tokens comprising the top 10% probability mass are considered.
|
||||
(default: :obj:`None`)
|
||||
stream (bool, optional): If True, partial message deltas will be sent
|
||||
as data-only server-sent events as they become available.
|
||||
(default: :obj:`None`)
|
||||
stop (str or list, optional): Up to :obj:`4` sequences where the API
|
||||
will stop generating further tokens. (default: :obj:`None`)
|
||||
max_tokens (int, optional): The maximum number of tokens to generate
|
||||
in the chat completion. The total length of input tokens and
|
||||
generated tokens is limited by the model's context length.
|
||||
(default: :obj:`None`)
|
||||
tools (list[FunctionTool], optional): A list of tools the model may
|
||||
call. Currently, only functions are supported as a tool. Use this
|
||||
to provide a list of functions the model may generate JSON inputs
|
||||
for. A max of 128 functions are supported.
|
||||
tool_choice (Union[dict[str, str], str], optional): Controls which (if
|
||||
any) tool is called by the model. :obj:`"none"` means the model
|
||||
will not call any tool and instead generates a message.
|
||||
:obj:`"auto"` means the model can pick between generating a
|
||||
message or calling one or more tools. :obj:`"required"` means the
|
||||
model must call one or more tools. Specifying a particular tool
|
||||
via {"type": "function", "function": {"name": "my_function"}}
|
||||
forces the model to call that tool. :obj:`"none"` is the default
|
||||
when no tools are present. :obj:`"auto"` is the default if tools
|
||||
are present.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[Union[str, Sequence[str]]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
tool_choice: Optional[Union[dict[str, str], str]] = None
|
||||
|
||||
|
||||
ZHIPUAI_API_PARAMS = {param for param in ZhipuAIConfig.model_fields.keys()}
|
||||
19
camel/data_collector/__init__.py
Normal file
19
camel/data_collector/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .alpaca_collector import AlpacaDataCollector
|
||||
from .base import BaseDataCollector
|
||||
from .sharegpt_collector import ShareGPTDataCollector
|
||||
|
||||
__all__ = ["BaseDataCollector", "AlpacaDataCollector", "ShareGPTDataCollector"]
|
||||
127
camel/data_collector/alpaca_collector.py
Normal file
127
camel/data_collector/alpaca_collector.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.data_collector.base import BaseDataCollector
|
||||
from camel.messages import AlpacaItem, BaseMessage
|
||||
from camel.schemas import OpenAISchemaConverter
|
||||
|
||||
# ruff: noqa: E501
|
||||
DEFAULT_CONVERTER_PROMPTS = """
|
||||
Extract key entities and attributes from the conversations
|
||||
and convert them into a structured JSON format.
|
||||
For example:
|
||||
Instruction: You are a helpful assistant.
|
||||
User: When is the release date of the video game Portal?
|
||||
Assistant: The release date of the video game Portal is October 9.
|
||||
Your output should be:
|
||||
{
|
||||
"instruction": "You are a helpful assistant. When is the release date of the video game Portal?",
|
||||
"input": "",
|
||||
"output": "The release date of the video game Portal is October 9."
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class AlpacaDataCollector(BaseDataCollector):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.system_message: Optional[BaseMessage] = None
|
||||
self.agent_name: Optional[str] = None
|
||||
|
||||
def record(
|
||||
self,
|
||||
agent: Union[List[ChatAgent], ChatAgent],
|
||||
) -> Self:
|
||||
r"""Inject an agent into the data collector.
|
||||
|
||||
Args:
|
||||
agent (Union[List[ChatAgent], ChatAgent]):
|
||||
The agent to inject.
|
||||
"""
|
||||
if not self.agent_name:
|
||||
_agent = agent if isinstance(agent, ChatAgent) else agent[0]
|
||||
self.agent_name = _agent.role_name
|
||||
self.system_message = _agent._system_message
|
||||
super().record(agent)
|
||||
return self
|
||||
|
||||
def convert(self) -> Dict[str, Any]:
|
||||
r"""Convert the collected data into a dictionary."""
|
||||
if self.agent_name is None:
|
||||
raise ValueError("No agent injected")
|
||||
|
||||
history = self.get_agent_history(self.agent_name)
|
||||
if not history:
|
||||
raise ValueError("No data collected.")
|
||||
|
||||
# Validate and process history
|
||||
if len(history) == 3 and history[0].role == "system":
|
||||
history = history[1:] # Ignore the system message.
|
||||
elif len(history) != 2:
|
||||
raise ValueError(
|
||||
f"AlpacaDataCollector only supports one message pair, but "
|
||||
f"got {len(history)}"
|
||||
)
|
||||
|
||||
input_message, output_message = history
|
||||
instruction = (
|
||||
self.system_message.content if self.system_message else ""
|
||||
) + str(input_message.message)
|
||||
|
||||
data = {
|
||||
"instruction": instruction,
|
||||
"input": "",
|
||||
"output": output_message.message,
|
||||
}
|
||||
self.data.append(data)
|
||||
return data
|
||||
|
||||
def llm_convert(
|
||||
self,
|
||||
converter: Optional[OpenAISchemaConverter] = None,
|
||||
prompt: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
r"""Convert collected data using an LLM schema converter.
|
||||
|
||||
Args:
|
||||
converter (Optional[OpenAISchemaConverter], optional):
|
||||
The converter to use. (default: :obj:`OpenAISchemaConverter`)
|
||||
prompt (Optional[str], optional): Prompt to guide the conversion.
|
||||
(default: :obj:`DEFAULT_CONVERTER_PROMPTS`)
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: The converted data.
|
||||
|
||||
Raises:
|
||||
ValueError: If no agent is injected or data cannot be collected.
|
||||
"""
|
||||
prompt = prompt or DEFAULT_CONVERTER_PROMPTS
|
||||
converter = converter or OpenAISchemaConverter()
|
||||
|
||||
system = self.system_message.content if self.system_message else ""
|
||||
context = [f"Instruction: {system}\n"]
|
||||
|
||||
for message in self.get_agent_history(str(self.agent_name)):
|
||||
if message.role == "user":
|
||||
context.append(f"User: {message.message}\n")
|
||||
else:
|
||||
context.append(f"{message.name}: {message.message}\n")
|
||||
return converter.convert(
|
||||
"\n".join(context), AlpacaItem, prompt=prompt
|
||||
).model_dump()
|
||||
211
camel/data_collector/base.py
Normal file
211
camel/data_collector/base.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
|
||||
|
||||
class CollectorData:
|
||||
def __init__(
|
||||
self,
|
||||
id: UUID,
|
||||
name: str,
|
||||
role: Literal["user", "assistant", "system", "tool"],
|
||||
message: Optional[str] = None,
|
||||
function_call: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
r"""Create a data item store information about a message.
|
||||
Used by the data collector.
|
||||
|
||||
Args:
|
||||
|
||||
id (UUID): The id of the message.
|
||||
name (str): The name of the agent.
|
||||
role (Literal["user", "assistant", "system", "function"]):
|
||||
The role of the message.
|
||||
message (Optional[str], optional): The message.
|
||||
(default: :obj:`None`)
|
||||
function_call (Optional[Dict[str, Any]], optional):
|
||||
The function call. (default: :obj:`None`)
|
||||
|
||||
Raises:
|
||||
|
||||
ValueError: If the role is not supported.
|
||||
ValueError: If the role is system and function call is provided.
|
||||
ValueError: If neither message nor function call is provided.
|
||||
|
||||
"""
|
||||
if role not in ["user", "assistant", "system", "tool"]:
|
||||
raise ValueError(f"Role {role} not supported")
|
||||
if role == "system" and function_call:
|
||||
raise ValueError("System role cannot have function call")
|
||||
if not message and not function_call:
|
||||
raise ValueError(
|
||||
"Either message or function call must be provided"
|
||||
)
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.role = role
|
||||
self.message = message
|
||||
self.function_call = function_call
|
||||
|
||||
@staticmethod
|
||||
def from_context(name, context: Dict[str, Any]) -> "CollectorData":
|
||||
r"""Create a data collector from a context.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
context (Dict[str, Any]): The context.
|
||||
|
||||
Returns:
|
||||
CollectorData: The data collector.
|
||||
"""
|
||||
return CollectorData(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
role=context["role"],
|
||||
message=context["content"],
|
||||
function_call=context.get("tool_calls", None),
|
||||
)
|
||||
|
||||
|
||||
class BaseDataCollector(ABC):
|
||||
r"""Base class for data collectors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
r"""Create a data collector."""
|
||||
self.history: List[CollectorData] = []
|
||||
self._recording = False
|
||||
self.agents: List[Tuple[str, ChatAgent]] = []
|
||||
self.data: List[Dict[str, Any]] = []
|
||||
|
||||
def step(
|
||||
self,
|
||||
role: Literal["user", "assistant", "system", "tool"],
|
||||
name: Optional[str] = None,
|
||||
message: Optional[str] = None,
|
||||
function_call: Optional[Dict[str, Any]] = None,
|
||||
) -> Self:
|
||||
r"""Record a message.
|
||||
|
||||
Args:
|
||||
role (Literal["user", "assistant", "system", "tool"]):
|
||||
The role of the message.
|
||||
name (Optional[str], optional): The name of the agent.
|
||||
(default: :obj:`None`)
|
||||
message (Optional[str], optional): The message to record.
|
||||
(default: :obj:`None`)
|
||||
function_call (Optional[Dict[str, Any]], optional):
|
||||
The function call to record. (default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
Self: The data collector.
|
||||
|
||||
"""
|
||||
|
||||
name = name or role
|
||||
|
||||
self.history.append(
|
||||
CollectorData(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
role=role,
|
||||
message=message,
|
||||
function_call=function_call,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def record(
|
||||
self,
|
||||
agent: Union[List[ChatAgent], ChatAgent],
|
||||
) -> Self:
|
||||
r"""Record agents.
|
||||
|
||||
Args:
|
||||
agent (Union[List[ChatAgent], ChatAgent]):
|
||||
The agent(s) to inject.
|
||||
"""
|
||||
if not isinstance(agent, list):
|
||||
agent = [agent]
|
||||
for a in agent:
|
||||
name = a.role_name
|
||||
if not name:
|
||||
name = f"{a.__class__.__name__}_{len(self.agents)}"
|
||||
if name in [n for n, _ in self.agents]:
|
||||
raise ValueError(f"Name {name} already exists")
|
||||
|
||||
self.agents.append((name, a))
|
||||
return self
|
||||
|
||||
def start(self) -> Self:
|
||||
r"""Start recording."""
|
||||
self._recording = True
|
||||
return self
|
||||
|
||||
def stop(self) -> Self:
|
||||
r"""Stop recording."""
|
||||
self._recording = False
|
||||
return self
|
||||
|
||||
@property
|
||||
def recording(self) -> bool:
|
||||
r"""Whether the collector is recording."""
|
||||
return self._recording
|
||||
|
||||
def reset(self, reset_agents: bool = True):
|
||||
r"""Reset the collector.
|
||||
|
||||
Args:
|
||||
reset_agents (bool, optional):
|
||||
Whether to reset the agents. Defaults to True.
|
||||
"""
|
||||
self.history = []
|
||||
if reset_agents:
|
||||
for _, agent in self.agents:
|
||||
agent.reset()
|
||||
|
||||
@abstractmethod
|
||||
def convert(self) -> Any:
|
||||
r"""Convert the collected data."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def llm_convert(self, converter: Any, prompt: Optional[str] = None) -> Any:
|
||||
r"""Convert the collected data."""
|
||||
pass
|
||||
|
||||
def get_agent_history(self, name: str) -> List[CollectorData]:
|
||||
r"""Get the message history of an agent.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
|
||||
Returns:
|
||||
List[CollectorData]: The message history of the agent
|
||||
"""
|
||||
if not self.history:
|
||||
for _name, agent in self.agents:
|
||||
if _name == name:
|
||||
return [
|
||||
CollectorData.from_context(name, dict(i))
|
||||
for i in agent.memory.get_context()[0]
|
||||
]
|
||||
return [msg for msg in self.history if msg.name == name]
|
||||
216
camel/data_collector/sharegpt_collector.py
Normal file
216
camel/data_collector/sharegpt_collector.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.data_collector.base import BaseDataCollector
|
||||
from camel.messages import BaseMessage
|
||||
from camel.messages.conversion.conversation_models import (
|
||||
ShareGPTConversation,
|
||||
ShareGPTMessage,
|
||||
)
|
||||
from camel.schemas import OpenAISchemaConverter
|
||||
from camel.toolkits import FunctionTool
|
||||
|
||||
FROM_HASH = {
|
||||
"human": "human",
|
||||
"gpt": "gpt",
|
||||
"observation": "human",
|
||||
"function_call": "gpt",
|
||||
}
|
||||
# ruff: noqa: E501
|
||||
DEFAULT_CONVERTER_PROMPTS = """
|
||||
Extract key entities and attributes from the conversations
|
||||
and convert them into a structured JSON format.
|
||||
For example:
|
||||
System: You are a helpful assistant
|
||||
Tools: [{"name": "get_release_date", "arguments": ["Portal"]}]
|
||||
User: When is the release date of the video game Portal?
|
||||
Assistant: The release date of the video game Portal is October 9, 2007.
|
||||
Your output should be:
|
||||
{
|
||||
"system": "You are a helpful assistant",
|
||||
"tools": "[{"name": "get_release_date", "arguments": ["Portal"]}]",
|
||||
"conversations": [
|
||||
{"from": "human", "value": "When is the release date of the video game Portal?"},
|
||||
{"from": "gpt", "value": "The release date of the video game Portal is October 9, 2007."}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class ConversationItem(BaseModel):
|
||||
from_: Literal["human", "gpt", "function_call", "observation"]
|
||||
value: str
|
||||
|
||||
class Config:
|
||||
fields: ClassVar[Dict[str, str]] = {"from_": "from"}
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class ShareGPTData(BaseModel):
|
||||
system: str
|
||||
tools: str
|
||||
conversations: List[ConversationItem]
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class ShareGPTDataCollector(BaseDataCollector):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.system_message: Optional[BaseMessage] = None
|
||||
self.agent_name: Optional[str] = None
|
||||
self.tools: List[FunctionTool] = []
|
||||
|
||||
def record(
|
||||
self,
|
||||
agent: Union[List[ChatAgent], ChatAgent],
|
||||
) -> Self:
|
||||
r"""Inject an agent into the data collector."""
|
||||
if not self.agent_name:
|
||||
_agent = agent if isinstance(agent, ChatAgent) else agent[0]
|
||||
self.agent_name = _agent.role_name
|
||||
self.system_message = _agent._system_message
|
||||
self.tools += list(_agent.tool_dict.values())
|
||||
|
||||
super().record(agent)
|
||||
return self
|
||||
|
||||
def convert(self) -> Dict[str, Any]:
|
||||
r"""Convert the collected data into a dictionary."""
|
||||
if self.agent_name is None:
|
||||
raise ValueError("No agent injected")
|
||||
|
||||
history = self.get_agent_history(self.agent_name)
|
||||
if not history:
|
||||
raise ValueError("No data collected.")
|
||||
|
||||
data = dict(
|
||||
system=self.system_message.content if self.system_message else "",
|
||||
tools=json.dumps(
|
||||
[t.get_openai_tool_schema()["function"] for t in self.tools]
|
||||
),
|
||||
ensure_ascii=False,
|
||||
conversations=[],
|
||||
)
|
||||
|
||||
conversations: List[Any] = []
|
||||
for _data in history:
|
||||
role, message = _data.role, _data
|
||||
|
||||
if role == "user":
|
||||
conversations.append(
|
||||
{"from": "human", "value": message.message}
|
||||
)
|
||||
elif role == "assistant":
|
||||
if message.function_call:
|
||||
conversations.append(
|
||||
{
|
||||
"from": "function_call",
|
||||
"value": json.dumps(
|
||||
message.function_call, ensure_ascii=False
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
conversations.append(
|
||||
{"from": "gpt", "value": message.message}
|
||||
)
|
||||
elif role == "function" or role == "tool":
|
||||
conversations.append(
|
||||
{
|
||||
"from": "observation",
|
||||
"value": json.dumps(
|
||||
message.message, ensure_ascii=False
|
||||
), # type: ignore[attr-defined]
|
||||
}
|
||||
)
|
||||
data["conversations"] = conversations
|
||||
|
||||
self.data.append(data)
|
||||
return data
|
||||
|
||||
def llm_convert(
|
||||
self,
|
||||
converter: Optional[OpenAISchemaConverter] = None,
|
||||
prompt: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
r"""Convert collected data using an LLM schema converter.
|
||||
|
||||
Args:
|
||||
converter (Optional[OpenAISchemaConverter], optional):
|
||||
The converter to use. (default: :obj:`OpenAISchemaConverter`)
|
||||
prompt (Optional[str], optional): Prompt to guide the conversion.
|
||||
(default: :obj:`DEFAULT_CONVERTER_PROMPTS`)
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: The converted data.
|
||||
|
||||
Raises:
|
||||
ValueError: If no agent is injected or data cannot be collected.
|
||||
"""
|
||||
prompt = prompt or DEFAULT_CONVERTER_PROMPTS
|
||||
converter = converter or OpenAISchemaConverter()
|
||||
|
||||
system = self.system_message.content if self.system_message else ""
|
||||
context = [f"System: {system}\n"]
|
||||
|
||||
context.append(
|
||||
"Tools: "
|
||||
+ json.dumps(
|
||||
[t.get_openai_tool_schema()["function"] for t in self.tools],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
for _data in self.get_agent_history(str(self.agent_name)):
|
||||
role, message = _data.role, _data
|
||||
prefix = (
|
||||
f"{role}: " if role != "user" else "User: " + f"{_data.name}: "
|
||||
)
|
||||
if message.function_call:
|
||||
context.append(
|
||||
prefix
|
||||
+ json.dumps(message.function_call, ensure_ascii=False)
|
||||
)
|
||||
|
||||
elif role == "function" or role == "tool":
|
||||
context.append(
|
||||
prefix + json.dumps(message.message, ensure_ascii=False)
|
||||
) # type: ignore[attr-defined]
|
||||
else:
|
||||
context.append(prefix + str(message.message))
|
||||
return converter.convert(
|
||||
"\n".join(context), ShareGPTData, prompt
|
||||
).model_dump()
|
||||
|
||||
@staticmethod
|
||||
def to_sharegpt_conversation(data: Dict[str, Any]) -> ShareGPTConversation:
|
||||
messages = [
|
||||
ShareGPTMessage(from_="system", value=data["system"]) # type: ignore[call-arg]
|
||||
]
|
||||
for item in data["conversations"]:
|
||||
messages.append(
|
||||
ShareGPTMessage( # type: ignore[call-arg]
|
||||
from_=FROM_HASH[item["from"]],
|
||||
value=item["value"],
|
||||
)
|
||||
)
|
||||
return ShareGPTConversation(root=messages)
|
||||
23
camel/datagen/__init__.py
Normal file
23
camel/datagen/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .cot_datagen import CoTDataGenerator
|
||||
from .self_improving_cot import SelfImprovingCoTPipeline
|
||||
from .self_instruct import SelfInstructPipeline
|
||||
|
||||
__all__ = [
|
||||
"CoTDataGenerator",
|
||||
"SelfInstructPipeline",
|
||||
"SelfImprovingCoTPipeline",
|
||||
]
|
||||
448
camel/datagen/cot_datagen.py
Normal file
448
camel/datagen/cot_datagen.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
|
||||
# Get a logger for this module
|
||||
logger = get_logger('CoTDataGenerator')
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
r"""Model for structured agent responses.
|
||||
|
||||
A Pydantic model class that represents structured responses from agents,
|
||||
including a similarity score that measures the quality of the response.
|
||||
|
||||
Args:
|
||||
score (float): A similarity score between 0 and 1 that compares the
|
||||
current answer to the correct answer. Must be within the range
|
||||
[0, 1].
|
||||
"""
|
||||
|
||||
score: Annotated[float, confloat(ge=0, le=1)] = Field(
|
||||
...,
|
||||
description="""Similarity score between 0 and 1
|
||||
comparing current answer to correct answer""",
|
||||
)
|
||||
|
||||
|
||||
class VerificationResponse(BaseModel):
|
||||
r"""Model for structured verification responses.
|
||||
|
||||
A Pydantic model class that represents verification results from agents,
|
||||
indicating whether an answer is correct or not.
|
||||
|
||||
Args:
|
||||
is_correct (bool): Boolean indicating if the answer is correct.
|
||||
"""
|
||||
|
||||
is_correct: bool = Field(
|
||||
...,
|
||||
description="Boolean indicating if the answer is correct",
|
||||
)
|
||||
|
||||
|
||||
class CoTDataGenerator:
|
||||
r"""Class for generating and managing data through chat agent interactions.
|
||||
|
||||
This module implements a sophisticated Chain of Thought data generation
|
||||
system that combines several key algorithms to produce high-quality
|
||||
reasoning paths. Methods implemented:
|
||||
|
||||
1. Monte Carlo Tree Search (MCTS)
|
||||
2. Binary Search Error Detection
|
||||
3. Dual-Agent Verification System
|
||||
4. Solution Tree Management
|
||||
|
||||
Args:
|
||||
chat_agent (Optional[ChatAgent]): Optional single agent
|
||||
for both tasks (legacy mode). (default::obj:`None`)
|
||||
generator_agent (Optional[ChatAgent]): Optional specialized agent for
|
||||
answer generation. (default::obj:`None`)
|
||||
verifier_agent (Optional[ChatAgent]): Optional specialized agent for
|
||||
answer verification. (default::obj:`None`)
|
||||
golden_answers (Dict[str, str]): Dictionary containing pre-defined
|
||||
correct answers for validation and comparison. Required for answer
|
||||
verification.
|
||||
search_limit (int): Maximum number of search iterations allowed.
|
||||
(default::obj:`100`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_agent: Optional[ChatAgent] = None,
|
||||
*,
|
||||
generator_agent: Optional[ChatAgent] = None,
|
||||
verifier_agent: Optional[ChatAgent] = None,
|
||||
golden_answers: Dict[str, str],
|
||||
search_limit: int = 100,
|
||||
):
|
||||
r"""Initialize the CoTDataGenerator.
|
||||
|
||||
This constructor supports both single-agent and dual-agent modes:
|
||||
1. Single-agent mode (legacy): Pass a single chat_agent that will be
|
||||
used for both generation and verification.
|
||||
2. Dual-agent mode: Pass separate generator_agent and verifier_agent
|
||||
for specialized tasks.
|
||||
|
||||
Args:
|
||||
chat_agent (Optional[ChatAgent]): Optional single agent for both
|
||||
tasks (legacy mode). (default::obj:`None`)
|
||||
generator_agent (Optional[ChatAgent]): Optional specialized agent
|
||||
for answer generation. (default::obj:`None`)
|
||||
verifier_agent (Optional[ChatAgent]): Optional specialized agent
|
||||
for answer verification. (default::obj:`None`)
|
||||
golden_answers (Dict[str, str]): Dictionary containing pre-defined
|
||||
correct answers for validation and comparison. Required for
|
||||
answer verification.
|
||||
search_limit (int): Maximum number of search iterations allowed.
|
||||
(default::obj:`100`)
|
||||
"""
|
||||
if chat_agent is not None:
|
||||
if generator_agent is not None or verifier_agent is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both chat_agent \
|
||||
and generator/verifier agents"
|
||||
)
|
||||
self.generator_agent = chat_agent
|
||||
self.verifier_agent = chat_agent
|
||||
else:
|
||||
if generator_agent is None or verifier_agent is None:
|
||||
raise ValueError(
|
||||
"Must specify either chat_agent or both generator and "
|
||||
"verifier agents"
|
||||
)
|
||||
self.generator_agent = generator_agent
|
||||
self.verifier_agent = verifier_agent
|
||||
|
||||
self.golden_answers = golden_answers
|
||||
self.search_limit = search_limit
|
||||
self.solution_tree: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||
logger.info(
|
||||
"CoTDataGenerator initialized with search_limit=%d", search_limit
|
||||
)
|
||||
|
||||
def get_answer(self, question: str, context: str = "") -> str:
|
||||
r"""Get an answer from the chat agent for a given question.
|
||||
|
||||
Args:
|
||||
question (str): The question to ask.
|
||||
context (str): Additional context for the question.
|
||||
(default::obj:`""`)
|
||||
|
||||
Returns:
|
||||
str: The generated answer.
|
||||
"""
|
||||
prompt = f"""
|
||||
Please think step by step and solve this problem: {question}
|
||||
Existing content: {context}
|
||||
Requirements:
|
||||
1. Analyze the problem requirements
|
||||
2. List the steps to solve the problem
|
||||
3. Execute the solution process
|
||||
4. Provide the final answer
|
||||
Please explain the thought process of each step in detail.
|
||||
"""
|
||||
self.generator_agent.reset()
|
||||
response = self.generator_agent.step(prompt)
|
||||
answer = response.msgs[0].content
|
||||
logger.info("AI thought process:\n%s", answer)
|
||||
return answer
|
||||
|
||||
def verify_answer(self, question: str, answer: str) -> bool:
|
||||
r"""Verify if a generated answer is semantically equivalent to
|
||||
the golden answer for a given question.
|
||||
|
||||
Args:
|
||||
question (str): The question being answered.
|
||||
answer (str): The answer to verify.
|
||||
|
||||
Returns:
|
||||
bool: True if the answer matches the golden answer based on
|
||||
semantic equivalence (meaning the core content and meaning are
|
||||
the same, even if the exact wording differs).
|
||||
False in the following cases:
|
||||
- If the provided question doesn't exist in the golden answers
|
||||
- If the answer's meaning differs from the golden answer
|
||||
"""
|
||||
golden_answer = self.golden_answers.get(question)
|
||||
if not golden_answer:
|
||||
raise ValueError(
|
||||
f"No golden answer found for question: {question}"
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"Question: {question}\n"
|
||||
f"Student Answer: {answer}\n"
|
||||
f"Correct Answer: {golden_answer}\n"
|
||||
"Is the student's answer correct? Please respond with 'true' or "
|
||||
"'false' only."
|
||||
)
|
||||
self.verifier_agent.reset()
|
||||
response = self.verifier_agent.step(
|
||||
prompt, response_format=VerificationResponse
|
||||
)
|
||||
is_correct = response.msgs[0].parsed.is_correct # type:ignore [union-attr]
|
||||
logger.info("Answer verification result: %s", is_correct)
|
||||
return is_correct
|
||||
|
||||
def monte_carlo_tree_search(
|
||||
self, question: str, partial_solution: str = ""
|
||||
) -> float:
|
||||
r"""Perform Monte Carlo Tree Search to find the best solution.
|
||||
|
||||
Process:
|
||||
a. Selection: Choose promising partial solutions based on previous
|
||||
scores
|
||||
b. Expansion: Generate new solution steps using the generator agent
|
||||
c. Simulation: Evaluate solution quality using similarity scores
|
||||
d. Backpropagation: Update solution tree with new findings
|
||||
|
||||
Args:
|
||||
question (str): The question to solve.
|
||||
partial_solution (str): The current partial solution.
|
||||
(default::obj:`""`)
|
||||
|
||||
Returns:
|
||||
float: The similarity score between the current
|
||||
solution and golden answer.
|
||||
"""
|
||||
if question not in self.golden_answers:
|
||||
raise ValueError(
|
||||
f"No golden answer found for question: {question}"
|
||||
)
|
||||
|
||||
golden_answer = self.golden_answers[question]
|
||||
|
||||
prompt = (
|
||||
f"Please evaluate this solution and "
|
||||
f"give a score between 0-1:\n"
|
||||
f"Question: {question}\n"
|
||||
f"Solution: {partial_solution}\n"
|
||||
f"Correct answer: {golden_answer}\n"
|
||||
f"Return a JSON object with a single field 'score' containing "
|
||||
f"a float between 0 and 1, like this: {{'score': 0.85}}\n"
|
||||
)
|
||||
self.generator_agent.reset()
|
||||
response = self.generator_agent.step(
|
||||
prompt, response_format=AgentResponse
|
||||
)
|
||||
agent_response = response.msgs[0].parsed.score # type: ignore [union-attr]
|
||||
|
||||
return agent_response
|
||||
|
||||
def binary_search_error(self, question: str, solution: str) -> int:
|
||||
r"""Use binary search to locate the first error in the solution.
|
||||
This method splits the solution into sentences using both English and
|
||||
Chinese sentence delimiters and performs binary search to find the
|
||||
first error.
|
||||
|
||||
Args:
|
||||
question (str): The question being solved.
|
||||
solution (str): The complete solution to analyze.
|
||||
|
||||
Returns:
|
||||
int: The position of the first error found in the solution.
|
||||
Returns -1. If no errors are found (all sentences are correct).
|
||||
"""
|
||||
logger.info("Starting binary search for error location")
|
||||
# Split by both English period and Chinese period
|
||||
sentences = [
|
||||
s.strip()
|
||||
for s in solution.replace('。', '.').split('.')
|
||||
if s.strip()
|
||||
]
|
||||
|
||||
# First check if the entire solution is correct
|
||||
if self.verify_answer(question, solution):
|
||||
return -1
|
||||
|
||||
left, right = 0, len(sentences)
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
partial_solution = '. '.join(sentences[:mid]) + '.'
|
||||
logger.info("Checking solution fragment:\n%s", partial_solution)
|
||||
# Verify if the current part is correct
|
||||
is_correct = self.verify_answer(question, partial_solution)
|
||||
if is_correct:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
logger.info("First error position found: sentence %d", left)
|
||||
return left
|
||||
|
||||
def solve(self, question: str) -> str:
|
||||
r"""Solve a question using a multi-step approach.
|
||||
|
||||
The solution process follows these steps:
|
||||
1. Try to solve directly - if correct, return the solution
|
||||
2. If not correct, use Monte Carlo Tree Search to find a good solution
|
||||
3. If the solution isn't perfect, use binary search to locate errors
|
||||
4. Generate a new solution based on the correct part
|
||||
|
||||
Args:
|
||||
question (str): The question to solve.
|
||||
|
||||
Returns:
|
||||
str: The best solution found.
|
||||
"""
|
||||
# 1. Try direct solution first
|
||||
solution = self.get_answer(question)
|
||||
if self.verify_answer(question, solution):
|
||||
logger.info("Initial solution is correct")
|
||||
return solution
|
||||
|
||||
# 2. If direct solution fails, try Monte Carlo Tree Search
|
||||
# to find a solution with high similarity score
|
||||
best_solution = ""
|
||||
best_score: float = 0.0
|
||||
for i in range(self.search_limit):
|
||||
# Generate new answer
|
||||
current_solution = self.get_answer(question, best_solution)
|
||||
|
||||
# Evaluate solution similarity score
|
||||
prompt = (
|
||||
f"Please evaluate this solution and "
|
||||
f"give a score between 0-1:\n"
|
||||
f"Question: {question}\n"
|
||||
f"Solution: {current_solution}\n"
|
||||
f"Correct answer: {self.golden_answers.get(question, '')}\n"
|
||||
f"Return a JSON object with a single field 'score' containing "
|
||||
f"a float between 0 and 1, like this: {{'score': 0.85}}\n"
|
||||
)
|
||||
self.generator_agent.reset()
|
||||
response = self.generator_agent.step(prompt)
|
||||
try:
|
||||
response = self.generator_agent.step(
|
||||
prompt, response_format=AgentResponse
|
||||
)
|
||||
agent_response = response.msgs[0].parsed.score # type: ignore [union-attr]
|
||||
score = agent_response
|
||||
|
||||
# Exit early if we find a very good solution (score > 0.9)
|
||||
if score > 0.9:
|
||||
logger.info(
|
||||
"Found excellent solution with score %.2f. "
|
||||
"Stopping search early.",
|
||||
score,
|
||||
)
|
||||
return current_solution
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_solution = current_solution
|
||||
|
||||
logger.info(
|
||||
"Current search progress: %d/%d, best score: %.2f",
|
||||
i + 1,
|
||||
self.search_limit,
|
||||
best_score,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error parsing agent response: %s", str(e))
|
||||
continue
|
||||
|
||||
# 3. If the answer is not completely correct,
|
||||
# use binary search to locate the error
|
||||
error_pos = self.binary_search_error(question, best_solution)
|
||||
|
||||
# If no errors found (error_pos == -1), return the current solution
|
||||
if error_pos == -1:
|
||||
logger.info("No specific errors found in the solution")
|
||||
return best_solution
|
||||
|
||||
# 4. Generate new solution based on correct part
|
||||
correct_part = '. '.join(best_solution.split('. ')[:error_pos]) + '.'
|
||||
final_solution = self.get_answer(question, correct_part)
|
||||
self.solution_tree[question] = {
|
||||
"solution": final_solution,
|
||||
"error_position": error_pos,
|
||||
}
|
||||
return final_solution
|
||||
|
||||
def import_qa_from_json(self, data: Union[str, Dict[str, str]]) -> bool:
|
||||
r"""Import question and answer data from either a JSON file or a
|
||||
dictionary.
|
||||
|
||||
Args:
|
||||
data (Union[str, Dict[str, str]]): Either a path to a JSON file
|
||||
containing QA pairs or a dictionary of question-answer pairs.
|
||||
If a string is provided, it's treated as a file path.
|
||||
The expected format is:
|
||||
{"question1": "answer1",
|
||||
"question2": "answer2",
|
||||
...}
|
||||
|
||||
Returns:
|
||||
bool: True if import was successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if isinstance(data, str):
|
||||
logger.info("Loading QA pairs from file: %s", data)
|
||||
with open(data, 'r', encoding='utf-8') as f:
|
||||
qa_data = json.load(f)
|
||||
else:
|
||||
logger.info("Loading QA pairs from provided dictionary")
|
||||
qa_data = data
|
||||
|
||||
# Validate the data format
|
||||
if not isinstance(qa_data, dict):
|
||||
logger.error("Invalid data format: expected dictionary")
|
||||
return False
|
||||
|
||||
# Update golden answers
|
||||
self.golden_answers.update(qa_data)
|
||||
logger.info("Successfully imported %d QA pairs", len(qa_data))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error importing QA data: %s", str(e))
|
||||
return False
|
||||
|
||||
def export_solutions(self, filepath: str = 'solutions.json') -> None:
|
||||
r"""Export the solution process and results to a JSON file.
|
||||
Exports the solution tree, golden answers,
|
||||
and export timestamp to a JSON file.
|
||||
The exported data includes:
|
||||
- solutions: The solution tree
|
||||
with intermediate steps
|
||||
- golden_answers: The reference answers used for verification
|
||||
- export_time: ISO format timestamp of the export
|
||||
|
||||
Args:
|
||||
filepath (str, optional): Path where the JSON file will be saved.
|
||||
(default::obj:`'solutions.json'`)
|
||||
|
||||
Returns:
|
||||
None: The method writes to a file and logs the result but does not
|
||||
return any value.
|
||||
"""
|
||||
export_data = {
|
||||
"solutions": self.solution_tree,
|
||||
"golden_answers": self.golden_answers,
|
||||
"export_time": datetime.now().isoformat(),
|
||||
}
|
||||
try:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Solutions exported successfully to {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting solutions: {e!s}")
|
||||
20
camel/datagen/evol_instruct/__init__.py
Normal file
20
camel/datagen/evol_instruct/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .evol_instruct import EvolInstructPipeline
|
||||
|
||||
__all__ = [
|
||||
'EvolInstructPipeline',
|
||||
'MathEvolInstructTemplates',
|
||||
]
|
||||
424
camel/datagen/evol_instruct/evol_instruct.py
Normal file
424
camel/datagen/evol_instruct/evol_instruct.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from math import ceil
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.datagen.evol_instruct.scorer import BaseScorer, GeneralScorer
|
||||
from camel.datagen.evol_instruct.templates import EvolInstructTemplates
|
||||
from camel.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EvolInstructPipeline:
|
||||
r"""Pipeline for evolving prompts using the Evol-Instruct methodology.
|
||||
|
||||
Supports custom templates defining evolution strategies and methods. The
|
||||
pipeline leverages language models to iteratively refine prompts through
|
||||
specified evolution strategies.
|
||||
|
||||
Args:
|
||||
templates (Type[EvolInstructTemplates]): Template class containing
|
||||
evolution strategy and method definitions. Must provide
|
||||
`EVOL_METHODS` and `STRATEGY` attributes.
|
||||
(default: :obj:`EvolInstructTemplates`)
|
||||
agent (Optional[ChatAgent]): Chat agent instance for LLM interaction.
|
||||
If :obj:`None`, initializes with a default ChatAgent.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
templates: Type = EvolInstructTemplates,
|
||||
agent: Optional[ChatAgent] = None,
|
||||
) -> None:
|
||||
r"""Initialize pipeline with templates and language model agent.
|
||||
|
||||
Args:
|
||||
templates (Type[EvolInstructTemplates]): Template class containing
|
||||
evolution strategy configurations.
|
||||
(default: :obj:`EvolInstructTemplates`)
|
||||
agent (Optional[ChatAgent]): Preconfigured chat agent instance.
|
||||
Creates a default ChatAgent if not provided.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
self.templates = templates
|
||||
self.agent = agent or ChatAgent()
|
||||
|
||||
def _resolve_evolution_method(self, method_key: str) -> str:
|
||||
r"""Resolve evolution method key to concrete implementation.
|
||||
|
||||
Args:
|
||||
method_key (str): Input method identifier. Can be:
|
||||
- Direct method key from templates.EVOL_METHODS
|
||||
- Strategy name from templates.STRATEGY keys
|
||||
|
||||
Returns:
|
||||
str: Resolved method key from EVOL_METHODS
|
||||
"""
|
||||
if method_key in self.templates.EVOL_METHODS:
|
||||
return method_key
|
||||
if method_key.upper() in self.templates.STRATEGY:
|
||||
strategy = self.templates.STRATEGY[method_key.upper()]
|
||||
strategy_methods = strategy["methods"]
|
||||
return random.choice(strategy_methods)
|
||||
|
||||
logger.warning(
|
||||
f"Invalid evolution method: {method_key}. "
|
||||
f"Using random selection."
|
||||
)
|
||||
return random.choice(list(self.templates.EVOL_METHODS))
|
||||
|
||||
def _get_evolution_methods(
|
||||
self,
|
||||
method: Union[str, List[str]],
|
||||
num_generations: int = 2,
|
||||
) -> List[str]:
|
||||
r"""Get list of evolution methods based on input specification.
|
||||
|
||||
Args:
|
||||
method (Union[str, List[str]]): Specification for method selection.
|
||||
Can be:
|
||||
- Strategy name for methods from that strategy
|
||||
- Specific method name
|
||||
- List of method specifications
|
||||
num_generations (int): Number of methods to return.
|
||||
|
||||
Returns:
|
||||
List[str]: List of resolved method names
|
||||
"""
|
||||
candidate_methods = []
|
||||
|
||||
if isinstance(method, list):
|
||||
for method_spec in method:
|
||||
candidate_methods.append(
|
||||
self._resolve_evolution_method(method_spec)
|
||||
)
|
||||
elif isinstance(method, str):
|
||||
if method.upper() in self.templates.STRATEGY:
|
||||
strategy = self.templates.STRATEGY[method.upper()]
|
||||
candidate_methods = strategy["methods"]
|
||||
else:
|
||||
candidate_methods = [self._resolve_evolution_method(method)]
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_candidates = []
|
||||
for method_name in candidate_methods:
|
||||
if method_name not in unique_candidates:
|
||||
unique_candidates.append(method_name)
|
||||
|
||||
if len(unique_candidates) >= num_generations:
|
||||
methods = random.sample(unique_candidates, num_generations)
|
||||
else:
|
||||
methods = unique_candidates.copy()
|
||||
while len(methods) < num_generations:
|
||||
methods.append(random.choice(unique_candidates))
|
||||
|
||||
return methods
|
||||
|
||||
def _generate_single_evolution(
|
||||
self,
|
||||
prompt: str,
|
||||
method: str,
|
||||
return_method: bool = False,
|
||||
) -> Tuple[str, str]:
|
||||
r"""Generate a single evolved prompt from a seed prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The seed prompt to evolve.
|
||||
method (str): The evolution method key to use.
|
||||
return_method (bool): If True, returns method along with prompt.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: Evolved prompt and method
|
||||
"""
|
||||
resolved_method = self._resolve_evolution_method(method)
|
||||
|
||||
# Find strategy containing the resolved method
|
||||
strategy_key = None
|
||||
for strategy, group in self.templates.STRATEGY.items():
|
||||
if resolved_method in group["methods"]:
|
||||
strategy_key = strategy
|
||||
break
|
||||
|
||||
if strategy_key is None:
|
||||
strategy_key = random.choice(list(self.templates.STRATEGY.keys()))
|
||||
|
||||
strategy = self.templates.STRATEGY[strategy_key]
|
||||
instruction_template = strategy["meta_instruction"]
|
||||
instruction = instruction_template.format(
|
||||
method=self.templates.EVOL_METHODS.get(
|
||||
resolved_method,
|
||||
random.choice(list(self.templates.EVOL_METHODS.values())),
|
||||
),
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
self.agent.reset()
|
||||
response = self.agent.step(instruction)
|
||||
evolved_prompt = response.msgs[0].content.strip()
|
||||
|
||||
if return_method:
|
||||
return (evolved_prompt, resolved_method)
|
||||
else:
|
||||
return (evolved_prompt, "")
|
||||
|
||||
def _generate_multiple_evolutions(
|
||||
self,
|
||||
prompt: str,
|
||||
method: Union[str, List[str]],
|
||||
num_generations: int = 2,
|
||||
keep_original: bool = True,
|
||||
num_threads: int = 10,
|
||||
) -> List[Tuple[str, str]]:
|
||||
r"""Generate multiple evolved versions of a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): Seed prompt to evolve.
|
||||
method (Union[str, List[str]]): Evolution method specification.
|
||||
num_generations (int): Candidates to generate per iteration.
|
||||
keep_original (bool): Whether to keep the original prompt.
|
||||
num_threads (int): Number of threads for parallel processing.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: List of (evolved_prompt, method) pairs
|
||||
"""
|
||||
results = [(prompt, "original")] if keep_original else []
|
||||
|
||||
if isinstance(method, list) and len(method) == num_generations:
|
||||
candidate_methods = method
|
||||
else:
|
||||
candidate_methods = self._get_evolution_methods(
|
||||
method=method, num_generations=num_generations
|
||||
)
|
||||
|
||||
def _process_single_method(method_name: str) -> Tuple[str, str]:
|
||||
return self._generate_single_evolution(
|
||||
prompt, method_name, return_method=True
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
evolved_results = list(
|
||||
executor.map(_process_single_method, candidate_methods)
|
||||
)
|
||||
|
||||
results.extend(evolved_results)
|
||||
return results
|
||||
|
||||
def _generate_iterative_evolutions(
|
||||
self,
|
||||
prompt: str,
|
||||
evolution_spec: Union[str, List[Union[str, List[str]]]],
|
||||
num_generations: int = 2,
|
||||
num_iterations: Optional[int] = None,
|
||||
keep_original: bool = True,
|
||||
scorer: Optional[BaseScorer] = None,
|
||||
num_threads: int = 10,
|
||||
) -> Dict[int, List[Dict[str, Any]]]:
|
||||
r"""Generate iterative evolutions of a prompt with scoring.
|
||||
|
||||
Args:
|
||||
prompt (str): Seed prompt to evolve.
|
||||
evolution_spec (Union[str, List[Union[str, List[str]]]]):
|
||||
Evolution method specification.
|
||||
If a list is provided and num_iterations is None, then
|
||||
num_iterations is set to the length of the list.
|
||||
num_generations (int): Candidates to generate per iteration.
|
||||
num_iterations (Optional[int]): Number of evolution iterations.
|
||||
Defaults to the length of evolution_spec.
|
||||
keep_original (bool): Include original prompt in results.
|
||||
scorer (Optional[BaseScorer]): Scoring model for candidate.
|
||||
num_threads (int): Number of threads for parallel processing.
|
||||
|
||||
Returns:
|
||||
Dict[int, List[Dict[str, Any]]]: Evolution results per iteration,
|
||||
where each candidate is represented as a dict with keys:
|
||||
"instruction", "method", and "scores".
|
||||
"""
|
||||
if num_iterations is None:
|
||||
if isinstance(evolution_spec, list):
|
||||
num_iterations = len(evolution_spec)
|
||||
else:
|
||||
num_iterations = 1
|
||||
|
||||
results = {}
|
||||
current_prompt = prompt
|
||||
scorer = scorer or GeneralScorer()
|
||||
|
||||
for iteration in range(num_iterations):
|
||||
if isinstance(evolution_spec, list):
|
||||
if iteration < len(evolution_spec):
|
||||
iteration_spec = evolution_spec[iteration]
|
||||
else:
|
||||
iteration_spec = evolution_spec[-1]
|
||||
else:
|
||||
iteration_spec = evolution_spec
|
||||
|
||||
batch_results = self._generate_multiple_evolutions(
|
||||
prompt=current_prompt,
|
||||
method=iteration_spec,
|
||||
num_generations=num_generations,
|
||||
keep_original=False,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
scored_results = []
|
||||
for candidate, method_used in batch_results:
|
||||
scores = scorer.score(current_prompt, candidate)
|
||||
scored_results.append(
|
||||
{
|
||||
"instruction": candidate,
|
||||
"method": method_used,
|
||||
"scores": scores,
|
||||
}
|
||||
)
|
||||
|
||||
best_index = max(
|
||||
range(len(scored_results)),
|
||||
key=lambda i: sum(
|
||||
cast(Dict[str, int], scored_results[i]["scores"]).values()
|
||||
),
|
||||
)
|
||||
|
||||
best_candidate = cast(
|
||||
str, scored_results[best_index]["instruction"]
|
||||
)
|
||||
|
||||
if keep_original:
|
||||
results[iteration] = [
|
||||
{
|
||||
"instruction": current_prompt,
|
||||
"method": "original",
|
||||
"scores": {},
|
||||
},
|
||||
*scored_results,
|
||||
]
|
||||
else:
|
||||
results[iteration] = scored_results
|
||||
|
||||
current_prompt = best_candidate
|
||||
|
||||
return results
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
evolution_spec: Union[str, List[Union[str, List[str]]]],
|
||||
num_generations: int = 2,
|
||||
num_iterations: Optional[int] = None,
|
||||
keep_original: bool = True,
|
||||
scorer: Optional[BaseScorer] = None,
|
||||
num_chunks: int = 1,
|
||||
retry_limit: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
num_threads: int = 10,
|
||||
) -> List[Dict[int, List[Dict[str, Any]]]]:
|
||||
r"""Evolve a batch of prompts through iterative refinement.
|
||||
|
||||
Args:
|
||||
prompts (List[str]): Seed prompts to evolve.
|
||||
evolution_spec (Union[str, List[Union[str, List[str]]]]):
|
||||
Evolution method specification.
|
||||
If a list is provided and num_iterations is None, then
|
||||
num_iterations is set to the length of the list.
|
||||
num_generations (int): Candidates to generate per iteration.
|
||||
num_iterations (Optional[int]): Number of evolution iterations.
|
||||
Defaults to the length of evolution_spec.
|
||||
keep_original (bool): Include original prompts in results.
|
||||
scorer (Optional[BaseScorer]): Scoring model for candidate.
|
||||
num_chunks (int): Number of parallel processing chunks.
|
||||
retry_limit (int): Max retries for failed generations.
|
||||
retry_delay (float): Delay between retries in seconds.
|
||||
num_threads (int): Number of threads for parallel processing.
|
||||
|
||||
Returns:
|
||||
List[Dict[int, List[Dict[str, Any]]]]: Evolution results.
|
||||
"""
|
||||
if num_iterations is None:
|
||||
if isinstance(evolution_spec, list):
|
||||
num_iterations = len(evolution_spec)
|
||||
else:
|
||||
num_iterations = 1
|
||||
|
||||
evolution_plan: List[List[List[str]]] = []
|
||||
for _ in prompts:
|
||||
prompt_plan = []
|
||||
for iteration in range(num_iterations):
|
||||
if isinstance(evolution_spec, list):
|
||||
if iteration < len(evolution_spec):
|
||||
raw_spec = evolution_spec[iteration]
|
||||
else:
|
||||
raw_spec = evolution_spec[-1]
|
||||
else:
|
||||
raw_spec = evolution_spec
|
||||
prompt_plan.append(
|
||||
self._get_evolution_methods(raw_spec, num_generations)
|
||||
)
|
||||
evolution_plan.append(prompt_plan)
|
||||
|
||||
def _process_prompt(
|
||||
args: Tuple[str, List[List[str]]],
|
||||
) -> Dict[int, List[Dict[str, Any]]]:
|
||||
prompt, methods = args
|
||||
retries = 0
|
||||
while retries <= retry_limit:
|
||||
try:
|
||||
return self._generate_iterative_evolutions(
|
||||
prompt=prompt,
|
||||
evolution_spec=evolution_spec,
|
||||
num_generations=num_generations,
|
||||
num_iterations=num_iterations,
|
||||
keep_original=keep_original,
|
||||
scorer=scorer,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
if retries <= retry_limit:
|
||||
logger.warning(
|
||||
f"Error processing prompt "
|
||||
f"(attempt {retries}/{retry_limit}): {e!s}"
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Failed to process prompt.")
|
||||
return {}
|
||||
|
||||
raise RuntimeError("_process_prompt() did not return.")
|
||||
|
||||
num_chunks = max(1, min(num_chunks, len(prompts)))
|
||||
chunk_size = ceil(len(prompts) / num_chunks)
|
||||
results = []
|
||||
|
||||
for chunk_idx in range(0, len(prompts), chunk_size):
|
||||
chunk = prompts[chunk_idx : chunk_idx + chunk_size]
|
||||
plan_chunk = evolution_plan[chunk_idx : chunk_idx + chunk_size]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
chunk_results = list(
|
||||
tqdm(
|
||||
executor.map(_process_prompt, zip(chunk, plan_chunk)),
|
||||
total=len(chunk),
|
||||
)
|
||||
)
|
||||
results.extend(chunk_results)
|
||||
|
||||
return results
|
||||
166
camel/datagen/evol_instruct/scorer.py
Normal file
166
camel/datagen/evol_instruct/scorer.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
@abstractmethod
|
||||
def score(
|
||||
self, reference_prompt: str, candidate_prompt: str
|
||||
) -> Dict[str, int]:
|
||||
r"""Compare a candidate prompt against a reference prompt and
|
||||
return a tuple of scores. The higher the score, the better.
|
||||
For example, (diversity, difficulty, feasibility).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MathScorer(BaseScorer):
|
||||
def __init__(self, agent: Optional[ChatAgent] = None):
|
||||
self.system_msg = (
|
||||
"You are an evaluator for math problems. Your task is to compare "
|
||||
"a new math problem against a reference math problem, and rate it "
|
||||
"in **four dimensions**, each scored from 1 to 5.\n\n"
|
||||
"1. Diversity (1-5): How novel is the new problem compared to the "
|
||||
"reference? 1 = very similar, 5 = completely different.\n"
|
||||
"2. Difficulty (1-5): Rate the relative difficulty compared to the"
|
||||
" reference problem. 1 = much less difficult, "
|
||||
"3 = similar difficulty, 5 = much more difficult.\n"
|
||||
"3. Validity (1-5): How well-defined and sound is the problem?"
|
||||
"1 = very vague or flawed, 5 = very clear and rigorous.\n"
|
||||
"4. Solvability (1-5): How likely is the problem solvable using "
|
||||
"standard math techniques? 1 = very unsolvable or ambiguous, "
|
||||
"5 = very clearly solvable.\n\n"
|
||||
"Respond with a JSON object like: "
|
||||
"{ \"diversity\": ..., \"difficulty\": ..., "
|
||||
"\"validity\": ..., \"solvability\": ... }"
|
||||
)
|
||||
self.agent = agent or ChatAgent(self.system_msg)
|
||||
|
||||
class MathScoreSchema(BaseModel):
|
||||
diversity: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Score for the diversity of the math problem "
|
||||
"compared to the reference"
|
||||
),
|
||||
)
|
||||
difficulty: int = Field(
|
||||
..., description="Score for the relative difficulty"
|
||||
)
|
||||
validity: int = Field(
|
||||
...,
|
||||
description="Score for how well-defined and sound the problem is",
|
||||
)
|
||||
solvability: int = Field(
|
||||
...,
|
||||
description="Score for the solvability of the problem",
|
||||
)
|
||||
|
||||
def score(
|
||||
self, reference_problem: str, new_problem: str
|
||||
) -> Dict[str, int]:
|
||||
r"""Evaluates the new math problem relative to the reference math
|
||||
problem.
|
||||
|
||||
Args:
|
||||
reference_problem (str): The reference math problem.
|
||||
new_problem (str): The new or evolved math problem.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary with scores for diversity, difficulty,
|
||||
validity, and solvability.
|
||||
"""
|
||||
query = (
|
||||
f"Reference problem:\n{reference_problem}\n\n"
|
||||
f"New problem:\n{new_problem}\n\n"
|
||||
"Provide scores in JSON format."
|
||||
)
|
||||
response = self.agent.step(query, response_format=self.MathScoreSchema)
|
||||
score_data = json.loads(response.msg.content)
|
||||
return score_data
|
||||
|
||||
|
||||
class GeneralScorer(BaseScorer):
|
||||
def __init__(self, agent: Optional[ChatAgent] = None):
|
||||
self.system_msg = (
|
||||
"You are an evaluator for problems in various domains. Your task "
|
||||
"is to compare a new problem against a reference problem, and rate"
|
||||
" it in **three dimensions**, each scored from 1 to 5.\n\n"
|
||||
"1. Diversity (1-5): How novel is the new problem compared to the "
|
||||
"reference? 1 = very similar, 5 = completely different.\n"
|
||||
"2. Complexity (1-5): Relative to the reference problem. "
|
||||
"1 = much less complex, 3 = similar complexity, "
|
||||
"5 = much more complex.\n"
|
||||
"3. Validity (1-5): How well-defined, meaningful, the problem is."
|
||||
"1 = vague/flawed, 5 = precise and fully meaningful.\n"
|
||||
"Respond with a JSON object like: "
|
||||
"{ \"diversity\": ..., \"complexity\": ..., \"validity\": ... }"
|
||||
)
|
||||
self.agent = agent or ChatAgent(self.system_msg)
|
||||
|
||||
class GeneralScoreSchema(BaseModel):
|
||||
diversity: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Score for the diversity of the problem "
|
||||
"compared to the reference."
|
||||
),
|
||||
)
|
||||
complexity: int = Field(
|
||||
...,
|
||||
description=("Score for the relative complexity of the problem."),
|
||||
)
|
||||
validity: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Score estimating the likelihood that the problem is "
|
||||
"well-defined."
|
||||
),
|
||||
)
|
||||
|
||||
def score(
|
||||
self, reference_problem: str, new_problem: str
|
||||
) -> Dict[str, int]:
|
||||
r"""Evaluates the new problem against the reference problem using
|
||||
structured scoring.
|
||||
|
||||
Args:
|
||||
reference_problem (str): The original problem.
|
||||
new_problem (str): The evolved or new problem.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary with scores for diversity, complexity,
|
||||
and validity.
|
||||
"""
|
||||
query = (
|
||||
f"Reference problem:\n{reference_problem}\n\n"
|
||||
f"New problem:\n{new_problem}\n\n"
|
||||
"Provide scores in JSON format."
|
||||
)
|
||||
response = self.agent.step(
|
||||
query, response_format=self.GeneralScoreSchema
|
||||
)
|
||||
score_data = json.loads(response.msg.content)
|
||||
return score_data
|
||||
268
camel/datagen/evol_instruct/templates.py
Normal file
268
camel/datagen/evol_instruct/templates.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union
|
||||
|
||||
|
||||
# flake8: noqa
|
||||
@dataclass(frozen=True)
|
||||
class BaseEvolInstructTemplates(ABC):
|
||||
r"""Abstract base class for evolution instruction templates.
|
||||
|
||||
This class defines a required structure for prompt transformation templates
|
||||
- `EVOL_METHODS`: A dictionary mapping method keys to their descriptions.
|
||||
- `STRATEGY`: A dictionary defining strategies and associated methods.
|
||||
|
||||
Subclasses should define concrete templates for specific domains.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def EVOL_METHODS(self) -> Dict[str, str]:
|
||||
r"""A dictionary mapping evolution method keys to their descriptions."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def STRATEGY(self) -> Dict[str, Dict[str, Union[str, List[str]]]]:
|
||||
r"""A dictionary defining strategies and their corresponding methods."""
|
||||
pass
|
||||
|
||||
|
||||
# flake8: noqa
|
||||
@dataclass(frozen=True)
|
||||
class EvolInstructTemplates(BaseEvolInstructTemplates):
|
||||
r"""Contains templates for EvolInstruct prompt transformations.
|
||||
|
||||
References:
|
||||
- WizardLM: Empowering Large Language Models to Follow Complex
|
||||
Instructions
|
||||
https://arxiv.org/pdf/2304.12244
|
||||
- eva: Evolving Alignment via Asymmetric Self-Play
|
||||
https://arxiv.org/abs/2411.00062
|
||||
"""
|
||||
|
||||
# High-level instructions on in-depth/in-breadth evolving
|
||||
INST_IN_DEPTH = (
|
||||
"Please act as an expert Prompt Creator.\n"
|
||||
"Your objective is to rewrite a given prompt into a more complex "
|
||||
"version to make those large language models (e.g., gemini) a bit "
|
||||
"harder to handle.\n"
|
||||
"But the rewritten prompt must be reasonable and must be understood "
|
||||
"and responded by humans.\n"
|
||||
"Your rewriting cannot omit the non-text parts such as the table and "
|
||||
"code in #Given Prompt#, if there is any."
|
||||
"You should try your best not to make the #Rewritten Prompt# become "
|
||||
"verbose, "
|
||||
"The #Rewritten Prompt# should be roughly the similar length or a "
|
||||
"little bit more than that of #Given Prompt#.\n"
|
||||
"The #Rewritten Prompt# must sound like a real human user's prompt; "
|
||||
"DON'T make it like sound machine-generated."
|
||||
"Specifically, you SHOULD complicate the given prompt using the "
|
||||
"following method: "
|
||||
"\n{method}\n"
|
||||
"The rewritten prompt should reflect meaningful changes across its "
|
||||
"structure, ensuring the entire sentence feels sufficiently different "
|
||||
"from the original. "
|
||||
"Again, make sure the rewritten prompt is more CHALLENGING."
|
||||
"Respond with your rewritten prompt directly. "
|
||||
"#Given Prompt#:\n{prompt}\n"
|
||||
"#Rewritten Prompt#:\n"
|
||||
).lstrip()
|
||||
|
||||
INST_IN_BREADTH = (
|
||||
"Please act as an expert Prompt Creator.\n"
|
||||
"Your objective is to generate a brand-new prompt based on the #Given "
|
||||
"Prompt#. "
|
||||
"The purpose of this task is to promote diversity and generality of "
|
||||
"training prompts for language models, helping it practice with "
|
||||
"varied challenges and perspectives.\n"
|
||||
"The LENGTH and complexity of the #Created Prompt# should be similar "
|
||||
"to that of the #Given Prompt#.\n"
|
||||
"The #Created Prompt# must be reasonable, interpretable, and solvable "
|
||||
"by humans.\n"
|
||||
"The #Created Prompt# must sound like a real human user's prompt; "
|
||||
"DON'T make it sound like machine-generated."
|
||||
"Follow the method described below to guide your creation:\n"
|
||||
"{method}\n"
|
||||
"The created prompt should reflect meaningful changes across its "
|
||||
"structure, ensuring the entire sentence feels sufficiently different "
|
||||
"from the original. "
|
||||
"Respond with your created prompt directly.\n"
|
||||
"#Given Prompt#:\n{prompt}\n"
|
||||
"#Created Prompt#:\n"
|
||||
).lstrip()
|
||||
|
||||
# Sub-method instructions (following the eva paper setting)
|
||||
IN_BREADTH_KEYS = [
|
||||
'persona',
|
||||
'shift-in',
|
||||
'shift-out',
|
||||
'mix',
|
||||
'abstract',
|
||||
]
|
||||
|
||||
IN_DEPTH_KEYS = [
|
||||
'constraints',
|
||||
'deepening',
|
||||
'concretizing',
|
||||
'reasoning',
|
||||
'expansion',
|
||||
]
|
||||
|
||||
STRATEGY = {
|
||||
"IN-DEPTH": {
|
||||
'meta_instruction': INST_IN_DEPTH,
|
||||
'methods': IN_DEPTH_KEYS,
|
||||
},
|
||||
"IN-BREADTH": {
|
||||
'meta_instruction': INST_IN_BREADTH,
|
||||
'methods': IN_BREADTH_KEYS,
|
||||
},
|
||||
}
|
||||
|
||||
EVOL_METHODS = {
|
||||
"persona": (
|
||||
"Reframe the #Given Prompt# as if written by a user with a "
|
||||
"completely different persona, background, or expertise. Adjust "
|
||||
"the tone, style, phrasing, or anything you feel proper to "
|
||||
"reflect this change. The changes should make the prompt feel "
|
||||
"like it was authored by someone entirely new."
|
||||
),
|
||||
"shift-in": (
|
||||
"Shift the high-level idea of the #Given Prompt# to explore a "
|
||||
"different subdomain or context within the same domain. Ensure "
|
||||
"the new topic still challenges the model to reason or provide "
|
||||
"knowledge relevant to the domain."
|
||||
),
|
||||
"shift-out": (
|
||||
"Shift the high-level idea of the #Given Prompt# to a completely "
|
||||
"different topic in a different setting. The new topic may "
|
||||
"challenge the model with similar reasoning or contextual "
|
||||
"understanding but in a novel way."
|
||||
),
|
||||
"mix": (
|
||||
"Combine the high-level concept of the #Given Prompt# with "
|
||||
"elements from a different domain. Introduce novel scenarios or "
|
||||
"contexts to create diversity while maintaining relevance to the "
|
||||
"original idea."
|
||||
),
|
||||
"abstract": (
|
||||
"Turn the #Given Prompt# into a more abstract or generalized "
|
||||
"version, removing specific details while preserving its intent. "
|
||||
"Ensure the new prompt encourages broader, principle-driven "
|
||||
"reasoning."
|
||||
),
|
||||
"constraints": (
|
||||
"Add one or more significant constraints or requirements into the "
|
||||
"'#Given Prompt#'. The added constraints must meaningfully alter "
|
||||
"how the model would respond. For example, specify additional "
|
||||
"rules, contexts, or limitations that demand creative adjustments."
|
||||
),
|
||||
"deepening": (
|
||||
"If the #Given Prompt# contains inquiries about certain issues, "
|
||||
"increase the depth and breadth of the inquiry. Make the question "
|
||||
"require a more detailed, multi-layered, or comprehensive response"
|
||||
". For instance, break the problem into sub-problems or require "
|
||||
"connections between unrelated concepts."
|
||||
),
|
||||
"concretizing": (
|
||||
"Replace general concepts in the #Given Prompt# with more specific"
|
||||
" and detailed concepts. Ensure that the change makes the problem "
|
||||
"more defined and concrete, leaving less room for ambiguity. For "
|
||||
"example, replace 'a device' with 'a wearable fitness tracker "
|
||||
"with GPS'."
|
||||
),
|
||||
"reasoning": (
|
||||
"Add one or more reasoning steps into the '#Given Prompt#'. "
|
||||
"Explicitly rewrite it to demand multi-step reasoning or justify "
|
||||
"intermediate steps in the solution. For instance, if the original"
|
||||
" prompt is a simple query, make the response require a "
|
||||
"step-by-step breakdown of logic or calculations."
|
||||
),
|
||||
"expansion": (
|
||||
"Expand the #Given Prompt# by including additional perspectives, "
|
||||
"domains, or layers of complexity. For example, if the original "
|
||||
"prompt focuses on a single scenario, add related scenarios or ask"
|
||||
" the model to compare different situations."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# flake8: noqa
|
||||
@dataclass(frozen=True)
|
||||
class MathEvolInstructTemplates(BaseEvolInstructTemplates):
|
||||
r"""Contains templates for MathEvolInstruct prompt transformations."""
|
||||
|
||||
# Meta-instructions for in-depth evolving
|
||||
INST_IN_DEPTH = (
|
||||
"Please act as a math expert. Your objective is to create a new math "
|
||||
"problem that is more challenging yet concise than the given math "
|
||||
"problem. Ensure that the mathematical content (including any "
|
||||
"equations or figures) is preserved, and rephrase the problem to "
|
||||
"increase its complexity and depth. The generated problem should be "
|
||||
"clearly stated, strictly mathematical, and suitable for solving with "
|
||||
"symbolic computation (e.g., using sympy). You will be given a method "
|
||||
"to guide your creation. Make sure to follow the method strictly. "
|
||||
"Consolidate any multiple parts into one integrated question that "
|
||||
"ask for one definitive answer. Respond with your generated problem "
|
||||
"directly. "
|
||||
"#Original Problem#:\n{prompt}\n"
|
||||
"#Generated Problem#:\n"
|
||||
).lstrip()
|
||||
|
||||
EVOL_METHODS = {
|
||||
"constraints": (
|
||||
"Add one or more significant constraints or requirements into the "
|
||||
"'#Given Prompt#'. The added constraints must meaningfully alter "
|
||||
"how the model would respond. For example, specify additional "
|
||||
"rules, contexts, or limitations that demand creative adjustments."
|
||||
),
|
||||
"deepening": (
|
||||
"Increase the difficulty of the #Given Prompt# by integrating "
|
||||
"additional layers of reasoning and rigor. Refine the problem so "
|
||||
"that all added difficulty is consolidated into a single coherent "
|
||||
"question requiring one final answer, avoiding fragmentation into "
|
||||
"multiple sub-problems."
|
||||
),
|
||||
"expansion": (
|
||||
"Expand the #Given Prompt# by incorporating additional "
|
||||
"perspectives or layers of complexity into the problem statement. "
|
||||
"Ensure that the revised problem remains a single, unified "
|
||||
"question with one final answer, rather than a series of separate "
|
||||
"sub-questions."
|
||||
),
|
||||
"condense": (
|
||||
"Reformulate the given math problem into a well-structured and "
|
||||
"formally stated mathematical question.\n"
|
||||
"- Present the problem in a structured and rigorous mathematical "
|
||||
"format.\n"
|
||||
"- Removing unnecessary instructions, explanations, or hints.\n"
|
||||
"- If the given problem contains several sub-questions, make "
|
||||
"necessary changes to let the problem could be answered with one "
|
||||
"number or expression by removing the sub-questions or combining "
|
||||
"them into one."
|
||||
),
|
||||
}
|
||||
|
||||
IN_DEPTH_KEYS = ['constraints', 'deepening', 'expansion']
|
||||
|
||||
STRATEGY = {
|
||||
"IN-DEPTH": {
|
||||
'meta_instruction': INST_IN_DEPTH,
|
||||
'methods': IN_DEPTH_KEYS,
|
||||
},
|
||||
}
|
||||
899
camel/datagen/self_improving_cot.py
Normal file
899
camel/datagen/self_improving_cot.py
Normal file
@@ -0,0 +1,899 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.models.reward import BaseRewardModel, Evaluator
|
||||
from camel.utils import BatchProcessor, retry_on_error
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AgentTraceEvaluation(BaseModel):
|
||||
correctness: float
|
||||
clarity: float
|
||||
completeness: float
|
||||
feedback: str
|
||||
|
||||
|
||||
class RewardTraceEvaluation(BaseModel):
|
||||
feedback: str
|
||||
|
||||
def __init__(self, **data):
|
||||
# Allow dynamic score fields while ensuring feedback is present
|
||||
super().__init__(**data)
|
||||
|
||||
class Config:
|
||||
extra = (
|
||||
"allow" # Allow extra fields for different reward model dimensions
|
||||
)
|
||||
|
||||
|
||||
class TraceIteration(BaseModel):
|
||||
iteration: int
|
||||
trace: str
|
||||
evaluation: Union[AgentTraceEvaluation, RewardTraceEvaluation]
|
||||
|
||||
|
||||
class ProblemResult(BaseModel):
|
||||
id: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
problem: str
|
||||
solution: Optional[str] = None
|
||||
final_trace: str
|
||||
agent_evaluate_success: Optional[bool] = None
|
||||
boxed_answer_success: bool = False
|
||||
improvement_history: List[TraceIteration]
|
||||
|
||||
|
||||
class SelfImprovingCoTPipeline:
|
||||
r"""Pipeline for generating self-taught reasoning traces
|
||||
using the self-improving methodology.
|
||||
|
||||
This implements the STaR paper's approach of:
|
||||
1. Initial reasoning trace generation
|
||||
2. Self-evaluation
|
||||
3. Feedback-based improvement
|
||||
4. Iterative refinement
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reason_agent: ChatAgent,
|
||||
problems: List[Dict],
|
||||
max_iterations: int = 3,
|
||||
score_threshold: Union[float, Dict[str, float]] = 0.7,
|
||||
rejection_sampling_n: Optional[int] = None,
|
||||
evaluate_agent: Optional[ChatAgent] = None,
|
||||
reward_model: Optional[BaseRewardModel] = None,
|
||||
output_path: Optional[str] = None,
|
||||
few_shot_examples: Optional[str] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
solution_pattern: str = r'\\boxed{(.*?)}',
|
||||
trace_pattern: Optional[str] = None,
|
||||
):
|
||||
r"""Initialize the self-improving cot pipeline.
|
||||
|
||||
Args:
|
||||
reason_agent (ChatAgent): The chat agent used for generating and
|
||||
improving reasoning traces.
|
||||
problems (List[Dict]): List of problem dictionaries to process.
|
||||
max_iterations (int, optional): Maximum number of improvement
|
||||
iterations. If set to `0`, the pipeline will generate an
|
||||
initial trace without any improvement iterations.
|
||||
(default: :obj:`3`)
|
||||
score_threshold (Union[float, Dict[str, float]], optional):
|
||||
Quality threshold. Can be either a single float value applied
|
||||
to average score, or a dictionary mapping score dimensions to
|
||||
their thresholds. For example: {"correctness": 0.8,
|
||||
"coherence": 0.7}. If using reward model and threshold for a
|
||||
dimension is not specified, will use the default value 0.7.
|
||||
(default: :obj:`0.7`)
|
||||
rejection_sampling_n (int, optional): Specifies the number of
|
||||
samples to be drawn using the rejection sampling
|
||||
method, where samples are accepted or rejected based on
|
||||
a predefined condition to achieve a desired distribution.
|
||||
(default: :obj: `None`)
|
||||
evaluate_agent (Optional[ChatAgent]): The chat agent used for
|
||||
evaluating reasoning traces. (default: :obj:`None`)
|
||||
reward_model (BaseRewardModel, optional): Model used to evaluate
|
||||
reasoning traces. If `None`, uses Agent self-evaluation.
|
||||
(default: :obj:`None`)
|
||||
output_path (str, optional): Output path for saving traces. If
|
||||
`None`, results will only be returned without saving to file.
|
||||
(default: :obj:`None`)
|
||||
few_shot_examples (str, optional): Examples to use for few-shot
|
||||
generation. (default: :obj:`None`)
|
||||
batch_size (int, optional): Batch size for parallel processing.
|
||||
(default: :obj:`None`)
|
||||
max_workers (int, optional): Maximum number of worker threads.
|
||||
(default: :obj:`None`)
|
||||
solution_pattern (str, optional): Regular expression pattern with
|
||||
one capture group to extract answers from solution text.
|
||||
(default: :obj:`r'\\boxed{(.*?)}'`)
|
||||
trace_pattern (str, optional): Regular expression pattern with one
|
||||
capture group to extract answers from trace text. If `None`,
|
||||
uses the same pattern as solution_pattern.
|
||||
(default: :obj:`None`)
|
||||
"""
|
||||
self.reason_agent = reason_agent
|
||||
self.evaluate_agent = evaluate_agent
|
||||
self.problems = problems
|
||||
self.output_path = output_path
|
||||
self.max_iterations = max_iterations
|
||||
self.score_threshold = score_threshold
|
||||
self.rejection_sampling_n = rejection_sampling_n
|
||||
self.reward_model = reward_model
|
||||
self.evaluator = (
|
||||
Evaluator(reward_model=reward_model) if reward_model else None
|
||||
)
|
||||
self.reasoning_traces: List[Dict[str, Any]] = []
|
||||
self.few_shot_examples = few_shot_examples
|
||||
self.batch_processor = BatchProcessor(max_workers, batch_size)
|
||||
self.solution_pattern = solution_pattern
|
||||
self.trace_pattern = (
|
||||
trace_pattern if trace_pattern is not None else solution_pattern
|
||||
)
|
||||
|
||||
# Initialize output file with empty results if path is specified
|
||||
if self.output_path:
|
||||
with open(self.output_path, 'w') as f:
|
||||
json.dump({'traces': []}, f, indent=2, ensure_ascii=False)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def safe_write_json(self, file_path, data):
|
||||
temp_path = file_path + ".tmp"
|
||||
with open(temp_path, "w") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
os.replace(temp_path, file_path)
|
||||
|
||||
def clean_json(self, data):
|
||||
if isinstance(data, dict):
|
||||
return {k: self.clean_json(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self.clean_json(v) for v in data]
|
||||
elif isinstance(data, float) and (
|
||||
math.isnan(data) or math.isinf(data)
|
||||
):
|
||||
return None
|
||||
return data
|
||||
|
||||
async def _batch_process_problems(
|
||||
self, problems: List[Dict], rationalization: bool
|
||||
) -> List[ProblemResult]:
|
||||
r"""Process multiple problems in parallel batches with dynamic sizing.
|
||||
|
||||
Args:
|
||||
problems (List[Dict]): List of problem dictionaries to process.
|
||||
rationalization (bool): Whether to use rationalization.
|
||||
|
||||
Returns:
|
||||
List[ProblemResult]: List of problem results.
|
||||
"""
|
||||
results = []
|
||||
total_problems = len(problems)
|
||||
processed = 0
|
||||
|
||||
while processed < total_problems:
|
||||
batch_size = self.batch_processor.batch_size
|
||||
batch = problems[processed : processed + batch_size]
|
||||
batch_start_time = time.time()
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=self.batch_processor.max_workers
|
||||
) as executor:
|
||||
# Create futures with rationalization parameter
|
||||
futures = [
|
||||
executor.submit(
|
||||
self.process_problem,
|
||||
problem=problem,
|
||||
rationalization=rationalization,
|
||||
)
|
||||
for problem in batch
|
||||
]
|
||||
|
||||
batch_results = []
|
||||
batch_success = True
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
batch_results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing problem: {e}")
|
||||
batch_success = False
|
||||
continue
|
||||
|
||||
results.extend(batch_results)
|
||||
processed += len(batch)
|
||||
|
||||
# Calculate processing time and adjust batch size
|
||||
processing_time = time.time() - batch_start_time
|
||||
self.batch_processor.adjust_batch_size(
|
||||
batch_success, processing_time
|
||||
)
|
||||
|
||||
# Log progress and performance metrics
|
||||
metrics = self.batch_processor.get_performance_metrics()
|
||||
logger.info(
|
||||
f"Processed {processed}/{total_problems} problems "
|
||||
f"(batch size: {batch_size}, workers: "
|
||||
f"{metrics['current_workers']}, "
|
||||
f"CPU: {metrics['current_cpu']:.1f}%, "
|
||||
f"Memory: {metrics['current_memory']:.1f}%)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch processing error: {e}")
|
||||
self.batch_processor.adjust_batch_size(False)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
async def _batch_evaluate_traces(
|
||||
self,
|
||||
problems: List[Dict[str, Any]],
|
||||
traces: List[str],
|
||||
solutions: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Evaluate multiple traces in parallel batches with resource
|
||||
monitoring.
|
||||
|
||||
Args:
|
||||
problems (List[Dict[str, Any]]): List of problem dictionaries
|
||||
traces (List[str]): List of reasoning traces to evaluate
|
||||
solutions (Optional[List[str]]): Optional list of solutions
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of evaluation results
|
||||
"""
|
||||
if solutions is None:
|
||||
solutions = ["null"] * len(problems)
|
||||
|
||||
results = []
|
||||
total_traces = len(traces)
|
||||
processed = 0
|
||||
|
||||
while processed < total_traces:
|
||||
batch_size = self.batch_processor.batch_size
|
||||
problem_batch = problems[processed : processed + batch_size]
|
||||
trace_batch = traces[processed : processed + batch_size]
|
||||
solution_batch = solutions[processed : processed + batch_size]
|
||||
batch_start_time = time.time()
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=self.batch_processor.max_workers
|
||||
) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
self.evaluate_trace,
|
||||
problem=problem["problem"],
|
||||
trace=trace,
|
||||
solution=solution,
|
||||
)
|
||||
for problem, trace, solution in zip(
|
||||
problem_batch, trace_batch, solution_batch
|
||||
)
|
||||
]
|
||||
|
||||
batch_results = []
|
||||
batch_success = True
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
batch_results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating trace: {e}")
|
||||
batch_success = False
|
||||
continue
|
||||
|
||||
results.extend(batch_results)
|
||||
processed += len(batch_results)
|
||||
|
||||
# Calculate processing time and adjust batch size
|
||||
processing_time = time.time() - batch_start_time
|
||||
self.batch_processor.adjust_batch_size(
|
||||
batch_success, processing_time
|
||||
)
|
||||
|
||||
# Log progress and performance metrics
|
||||
metrics = self.batch_processor.get_performance_metrics()
|
||||
logger.info(
|
||||
f"Evaluated {processed}/{total_traces} traces "
|
||||
f"(batch size: {batch_size}, workers: "
|
||||
f"{metrics['current_workers']}, "
|
||||
f"avg time: {metrics['avg_processing_time']:.2f}s, "
|
||||
f"error rate: {metrics['error_rate']:.1f}%)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch evaluation error: {e}")
|
||||
self.batch_processor.adjust_batch_size(False)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _check_score_threshold(self, scores: Dict[str, float]) -> bool:
|
||||
r"""Check if scores meet the threshold requirements.
|
||||
|
||||
Args:
|
||||
scores (Dict[str, float]): Dictionary of scores for different
|
||||
dimensions.
|
||||
|
||||
Returns:
|
||||
bool: True if scores meet threshold requirements, False otherwise.
|
||||
"""
|
||||
# If score_threshold is a float, apply it to all dimensions
|
||||
if isinstance(self.score_threshold, float):
|
||||
return all(
|
||||
score >= self.score_threshold for score in scores.values()
|
||||
)
|
||||
|
||||
# If score_threshold is a dict, check each dimension with its threshold
|
||||
# Use 0 as default threshold for unspecified dimensions
|
||||
if isinstance(self.score_threshold, dict):
|
||||
for dim, score in scores.items():
|
||||
threshold = self.score_threshold.get(dim, 0)
|
||||
if score < threshold:
|
||||
return False
|
||||
return True
|
||||
|
||||
# If score_threshold is None or invalid type, pass the check
|
||||
return True
|
||||
|
||||
def _generate_feedback(self, scores: Dict[str, float]) -> str:
|
||||
r"""Generate feedback based on which dimensions need improvement.
|
||||
|
||||
Args:
|
||||
scores (Dict[str, float]): Dictionary of scores for different
|
||||
dimensions.
|
||||
|
||||
Returns:
|
||||
str: Feedback message indicating which dimensions need improvement.
|
||||
"""
|
||||
if isinstance(self.score_threshold, float):
|
||||
below_threshold = [
|
||||
dim
|
||||
for dim, score in scores.items()
|
||||
if score < self.score_threshold
|
||||
]
|
||||
if not below_threshold:
|
||||
return "All dimensions meet the required threshold"
|
||||
dims = ", ".join(below_threshold)
|
||||
return f"Need improvement in: {dims}"
|
||||
|
||||
if isinstance(self.score_threshold, dict):
|
||||
default_threshold = 0
|
||||
below_threshold = [
|
||||
dim
|
||||
for dim, score in scores.items()
|
||||
if score < self.score_threshold.get(dim, default_threshold)
|
||||
]
|
||||
if not below_threshold:
|
||||
return "All dimensions meet their respective thresholds"
|
||||
dims = ", ".join(below_threshold)
|
||||
return f"Need improvement in: {dims}"
|
||||
|
||||
# If no threshold set, just list all dimensions and their scores
|
||||
dims = ", ".join(
|
||||
f"{dim}: {score:.2f}" for dim, score in scores.items()
|
||||
)
|
||||
return f"Current scores - {dims}"
|
||||
|
||||
@retry_on_error()
|
||||
def generate_reasoning_trace(self, problem: str) -> str:
|
||||
r"""Generate initial reasoning trace for a given problem.
|
||||
|
||||
Args:
|
||||
problem (str): The problem text to generate reasoning for.
|
||||
|
||||
Returns:
|
||||
str: Generated reasoning trace.
|
||||
"""
|
||||
self.reason_agent.reset()
|
||||
few_shot_examples = (
|
||||
f"Examples: {self.few_shot_examples}"
|
||||
if self.few_shot_examples
|
||||
else ""
|
||||
)
|
||||
prompt = self.REASONING_TEMPLATE.format(
|
||||
problem=problem, few_shot_examples=few_shot_examples
|
||||
)
|
||||
response = self.reason_agent.step(prompt)
|
||||
return response.msg.content
|
||||
|
||||
@retry_on_error()
|
||||
def evaluate_trace(
|
||||
self, problem: str, trace: str, solution: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
r"""Evaluate the quality of a reasoning trace.
|
||||
|
||||
Args:
|
||||
problem (str): The original problem text to evaluate against.
|
||||
trace (str): The reasoning trace to evaluate.
|
||||
solution (Optional[str]): The solution to the problem, if provided.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Evaluation results containing:
|
||||
- scores: Dict of evaluation dimensions and their scores
|
||||
- feedback: Detailed feedback for improvement
|
||||
|
||||
For Agent self-evaluation, the scores will include:
|
||||
- correctness: Score for logical correctness
|
||||
- clarity: Score for clarity of explanation
|
||||
- completeness: Score for completeness of reasoning
|
||||
|
||||
For reward model evaluation, the scores will depend on
|
||||
the model's evaluation dimensions.
|
||||
"""
|
||||
self.evaluate_agent.reset() # type: ignore[union-attr]
|
||||
if self.evaluator:
|
||||
# Use reward model evaluation
|
||||
messages = [
|
||||
{"role": "user", "content": problem},
|
||||
{"role": "assistant", "content": trace},
|
||||
]
|
||||
scores = self.evaluator.evaluate(messages)
|
||||
|
||||
# For models that return a single score
|
||||
if isinstance(scores, (int, float)) or (
|
||||
isinstance(scores, dict) and len(scores) == 1
|
||||
):
|
||||
if isinstance(scores, dict):
|
||||
score = next(iter(scores.values()))
|
||||
else:
|
||||
score = scores
|
||||
scores_dict = {"overall": score}
|
||||
return {
|
||||
**scores_dict,
|
||||
"feedback": self._generate_feedback(scores_dict),
|
||||
}
|
||||
|
||||
# For models that return multiple dimensions
|
||||
return {**scores, "feedback": self._generate_feedback(scores)}
|
||||
else:
|
||||
# Fallback to original Agent self-evaluation
|
||||
solution_text = f"Solution: {solution}" if solution else ""
|
||||
prompt = self.EVALUATION_TEMPLATE.format(
|
||||
problem=problem, trace=trace, solution=solution_text
|
||||
)
|
||||
response = self.evaluate_agent.step( # type: ignore[union-attr]
|
||||
prompt, response_format=AgentTraceEvaluation
|
||||
)
|
||||
if response.msg.parsed is None:
|
||||
raise AttributeError("Failed to parse evaluation response")
|
||||
# Convert dict to AgentTraceEvaluation if needed
|
||||
if isinstance(response.msg.parsed, dict):
|
||||
evaluation = AgentTraceEvaluation(**response.msg.parsed)
|
||||
else:
|
||||
evaluation = response.msg.parsed
|
||||
|
||||
return evaluation.model_dump()
|
||||
|
||||
@retry_on_error()
|
||||
def generate_reasoning_trace_rejection(self, problem: str) -> str:
|
||||
r"""Generate multiple candidate reasoning traces for a problem and
|
||||
select the best one based on evaluation.
|
||||
|
||||
Args:
|
||||
problem (str): The problem text for generating a reasoning trace.
|
||||
|
||||
Returns:
|
||||
str: The best candidate trace that meets quality criteria, or the
|
||||
first candidate if none qualify.
|
||||
"""
|
||||
few_shot_examples = (
|
||||
f"Examples: {self.few_shot_examples}"
|
||||
if self.few_shot_examples
|
||||
else ""
|
||||
)
|
||||
prompt = self.REASONING_TEMPLATE.format(
|
||||
problem=problem, few_shot_examples=few_shot_examples
|
||||
)
|
||||
responses, candidate_traces = None, []
|
||||
if 'n' in self.reason_agent.model_backend.model_config_dict:
|
||||
self.reason_agent.model_backend.model_config_dict['n'] = (
|
||||
self.rejection_sampling_n
|
||||
)
|
||||
# Generate multiple candidate traces in one call using parameter n
|
||||
responses = self.reason_agent.step(prompt)
|
||||
# Extract cancidate traces
|
||||
candidate_traces = [choice.content for choice in responses.msgs]
|
||||
else:
|
||||
sampling_n = (
|
||||
self.rejection_sampling_n
|
||||
if self.rejection_sampling_n is not None
|
||||
else 1
|
||||
)
|
||||
for _i in range(sampling_n):
|
||||
trace = self.generate_reasoning_trace(problem)
|
||||
candidate_traces.append(trace)
|
||||
|
||||
best_trace = None
|
||||
best_avg_score = 0.01
|
||||
candidate_avg_scores = []
|
||||
for trace in candidate_traces:
|
||||
eval_results = self.evaluate_trace(problem, trace)
|
||||
# Remove feedback from scores
|
||||
scores = {k: v for k, v in eval_results.items() if k != "feedback"}
|
||||
# Compute average score (assuming at least one score exists)
|
||||
if scores:
|
||||
avg_score = sum(scores.values()) / len(scores)
|
||||
else:
|
||||
avg_score = 0.0
|
||||
candidate_avg_scores.append(avg_score)
|
||||
# If the candidate meets the threshold and is the best, select it
|
||||
if (
|
||||
self._check_score_threshold(scores)
|
||||
and avg_score > best_avg_score
|
||||
):
|
||||
best_trace = trace
|
||||
best_avg_score = avg_score
|
||||
if best_trace is None:
|
||||
best_trace = candidate_traces[
|
||||
candidate_avg_scores.index(max(candidate_avg_scores))
|
||||
]
|
||||
return best_trace
|
||||
|
||||
@retry_on_error()
|
||||
def improve_trace(
|
||||
self,
|
||||
problem: str,
|
||||
trace: str,
|
||||
feedback: str,
|
||||
solution: Optional[str] = None,
|
||||
) -> str:
|
||||
r"""Generate improved reasoning trace based on feedback.
|
||||
|
||||
Args:
|
||||
problem (str): The original problem text.
|
||||
trace (str): The current reasoning trace.
|
||||
feedback (str): Feedback for improving the trace.
|
||||
solution (Optional[str]): The solution to the problem, if provided.
|
||||
(default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
str: Improved reasoning trace.
|
||||
"""
|
||||
self.reason_agent.reset()
|
||||
solution_text = f"Solution: {solution}" if solution else ""
|
||||
prompt = self.IMPROVEMENT_TEMPLATE.format(
|
||||
problem=problem,
|
||||
trace=trace,
|
||||
feedback=feedback,
|
||||
solution=solution_text,
|
||||
)
|
||||
response = self.reason_agent.step(prompt)
|
||||
return response.msg.content
|
||||
|
||||
def validate_problem_format(self, problem: Dict) -> None:
|
||||
r"""Validate that a problem dictionary has the required format.
|
||||
|
||||
Args:
|
||||
problem (Dict): Problem dictionary to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the problem format is invalid.
|
||||
"""
|
||||
if not isinstance(problem, dict):
|
||||
raise ValueError("Problem must be a dictionary.")
|
||||
|
||||
# Check required problem field
|
||||
if "problem" not in problem:
|
||||
raise ValueError("Problem dictionary must contain 'problem' key.")
|
||||
if not isinstance(problem["problem"], str):
|
||||
raise ValueError("Problem 'problem' field must be a string.")
|
||||
|
||||
# Optional fields validation
|
||||
optional_fields: dict[str, type | tuple[type, ...]] = {
|
||||
"id": (str, int, type(None)),
|
||||
"type": str,
|
||||
"solution": str,
|
||||
}
|
||||
|
||||
for field, expected_type in optional_fields.items():
|
||||
if field in problem and not isinstance(
|
||||
problem[field], expected_type
|
||||
):
|
||||
type_name = (
|
||||
expected_type.__name__
|
||||
if hasattr(expected_type, '__name__')
|
||||
else str(expected_type)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Problem '{field}' must be of "
|
||||
f"type {type_name} if present."
|
||||
)
|
||||
|
||||
def _check_boxed_answers(self, solution: str, trace: str) -> bool:
|
||||
r"""Check if the answer in the trace matches the solution using the
|
||||
configured patterns.
|
||||
|
||||
Args:
|
||||
solution (str): The problem solution string.
|
||||
trace (str): The reasoning trace string.
|
||||
|
||||
Returns:
|
||||
bool: True if answers match, False otherwise
|
||||
"""
|
||||
import re
|
||||
|
||||
# Extract content using the configured patterns
|
||||
solution_match = re.search(self.solution_pattern, solution, re.DOTALL)
|
||||
trace_match = re.search(self.trace_pattern, trace, re.DOTALL)
|
||||
|
||||
if solution_match and trace_match:
|
||||
# Clean up whitespace and normalize content
|
||||
solution_answer = solution_match.group(1).strip()
|
||||
trace_answer = trace_match.group(1).strip()
|
||||
return solution_answer == trace_answer
|
||||
|
||||
return False
|
||||
|
||||
def process_problem(
|
||||
self, problem: Dict, rationalization: bool = False
|
||||
) -> ProblemResult:
|
||||
r"""Process a single problem through the self-improving cot pipeline.
|
||||
|
||||
Args:
|
||||
problem (Dict): Problem dictionary containing the problem text.
|
||||
rationalization (bool, optional): Whether to use rationalization.
|
||||
(default: :obj:`False`)
|
||||
|
||||
Returns:
|
||||
ProblemResult: Results with final trace and history.
|
||||
|
||||
Raises:
|
||||
ValueError: If the problem format is invalid.
|
||||
"""
|
||||
# Validate problem format before processing
|
||||
self.validate_problem_format(problem)
|
||||
|
||||
problem_text = problem["problem"]
|
||||
solution_text = problem.get("solution", "")
|
||||
current_trace = None
|
||||
if self.rejection_sampling_n:
|
||||
current_trace = self.generate_reasoning_trace_rejection(
|
||||
problem_text
|
||||
)
|
||||
else:
|
||||
current_trace = self.generate_reasoning_trace(problem_text)
|
||||
improvement_history = []
|
||||
scores = {}
|
||||
|
||||
# Only evaluate if evaluate_agent or reward_model is set
|
||||
if self.evaluate_agent or self.reward_model:
|
||||
# Create batches for parallel evaluation
|
||||
batch_problems = [problem]
|
||||
batch_traces = [current_trace]
|
||||
batch_solutions = [solution_text]
|
||||
|
||||
# Evaluate current trace batch
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
eval_results = loop.run_until_complete(
|
||||
self._batch_evaluate_traces(
|
||||
batch_problems, batch_traces, batch_solutions
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Process evaluation results
|
||||
eval_dict = eval_results[-1] # Get latest evaluation
|
||||
scores = {k: v for k, v in eval_dict.items() if k != "feedback"}
|
||||
|
||||
# Record initial evaluation
|
||||
if self.evaluator:
|
||||
improvement_history.append(
|
||||
TraceIteration(
|
||||
iteration=0,
|
||||
trace=current_trace,
|
||||
evaluation=RewardTraceEvaluation(**eval_dict),
|
||||
)
|
||||
)
|
||||
else:
|
||||
improvement_history.append(
|
||||
TraceIteration(
|
||||
iteration=0,
|
||||
trace=current_trace,
|
||||
evaluation=AgentTraceEvaluation(
|
||||
**scores, feedback=eval_dict["feedback"]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Only do improvement iterations if max_iterations > 0
|
||||
if self.max_iterations > 0:
|
||||
for iteration in range(0, self.max_iterations):
|
||||
# Check if quality threshold met
|
||||
if self._check_score_threshold(scores):
|
||||
break
|
||||
|
||||
# Generate improved trace
|
||||
if rationalization:
|
||||
current_trace = self.improve_trace(
|
||||
problem_text,
|
||||
current_trace,
|
||||
eval_dict["feedback"],
|
||||
solution_text,
|
||||
)
|
||||
else:
|
||||
current_trace = self.improve_trace(
|
||||
problem_text, current_trace, eval_dict["feedback"]
|
||||
)
|
||||
|
||||
# Evaluate improved trace
|
||||
batch_traces = [current_trace]
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
eval_results = loop.run_until_complete(
|
||||
self._batch_evaluate_traces(
|
||||
batch_problems, batch_traces, batch_solutions
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
eval_dict = eval_results[-1]
|
||||
scores = {
|
||||
k: v for k, v in eval_dict.items() if k != "feedback"
|
||||
}
|
||||
|
||||
# Record iteration history
|
||||
if self.evaluator:
|
||||
improvement_history.append(
|
||||
TraceIteration(
|
||||
iteration=iteration + 1,
|
||||
trace=current_trace,
|
||||
evaluation=RewardTraceEvaluation(**eval_dict),
|
||||
)
|
||||
)
|
||||
else:
|
||||
improvement_history.append(
|
||||
TraceIteration(
|
||||
iteration=iteration + 1,
|
||||
trace=current_trace,
|
||||
evaluation=AgentTraceEvaluation(
|
||||
**scores, feedback=eval_dict["feedback"]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
boxed_answer_success = self._check_boxed_answers(
|
||||
problem.get("solution", ""), current_trace
|
||||
)
|
||||
|
||||
result = ProblemResult(
|
||||
id=problem.get("id", ""),
|
||||
type=problem.get("type", ""),
|
||||
problem=problem_text,
|
||||
solution=problem.get("solution", ""),
|
||||
final_trace=current_trace,
|
||||
agent_evaluate_success=self._check_score_threshold(scores)
|
||||
if scores
|
||||
else None,
|
||||
boxed_answer_success=boxed_answer_success,
|
||||
improvement_history=improvement_history,
|
||||
)
|
||||
|
||||
# Write result to file immediately if output path is specified
|
||||
if self.output_path:
|
||||
with self.lock:
|
||||
try:
|
||||
# Read existing results
|
||||
with open(self.output_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
cleaned_result = self.clean_json(result.model_dump())
|
||||
data['traces'].append(cleaned_result)
|
||||
self.safe_write_json(self.output_path, data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing result to file: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def generate(self, rationalization: bool = False) -> List[Dict[str, Any]]:
|
||||
r"""Execute the self-improving cot pipeline on all problems.
|
||||
|
||||
Process problems and return results. If output_path is specified,
|
||||
also save results to file.
|
||||
|
||||
Args:
|
||||
rationalization (bool, optional): Whether to use rationalization.
|
||||
(default: :obj:`False`)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of processed results
|
||||
"""
|
||||
# Pre-allocate results list
|
||||
self.reasoning_traces = []
|
||||
|
||||
# Process problems in batches
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
results = loop.run_until_complete(
|
||||
self._batch_process_problems(self.problems, rationalization)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
self.reasoning_traces = [result.model_dump() for result in results]
|
||||
return self.reasoning_traces
|
||||
|
||||
# Templates for generating reasoning, evaluation and improving them.
|
||||
REASONING_TEMPLATE = """Let's solve this step by step:
|
||||
Problem: {problem}
|
||||
1. First, let's understand what we're asked
|
||||
2. Let's break this down into parts
|
||||
3. Let's solve each part systematically
|
||||
4. Finally, let's verify our solution
|
||||
|
||||
{few_shot_examples}
|
||||
|
||||
Please show your complete reasoning process."""
|
||||
|
||||
EVALUATION_TEMPLATE = """Please evaluate this reasoning trace and
|
||||
provide scores and feedback in valid JSON format.
|
||||
|
||||
Problem: {problem}
|
||||
|
||||
{solution}
|
||||
|
||||
Reasoning Trace:
|
||||
{trace}
|
||||
|
||||
Evaluate for:
|
||||
1. Correctness (Is each step logically sound?)
|
||||
2. Clarity (Is the explanation clear and well-structured?)
|
||||
3. Completeness (Are all necessary steps included?)
|
||||
|
||||
Respond ONLY with a JSON object in this exact format:
|
||||
{{
|
||||
"correctness": <score between 0 and 1>,
|
||||
"clarity": <score between 0 and 1>,
|
||||
"completeness": <score between 0 and 1>,
|
||||
"feedback": "<specific feedback for improvement>"
|
||||
}}"""
|
||||
|
||||
IMPROVEMENT_TEMPLATE = """Based on this feedback, generate an
|
||||
improved reasoning trace:
|
||||
Problem: {problem}
|
||||
|
||||
{solution}
|
||||
|
||||
Previous Trace:
|
||||
{trace}
|
||||
|
||||
Feedback:
|
||||
{feedback}
|
||||
|
||||
Generate a new, improved reasoning trace that addresses the feedback."""
|
||||
36
camel/datagen/self_instruct/__init__.py
Normal file
36
camel/datagen/self_instruct/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .filter import (
|
||||
FILTER_REGISTRY,
|
||||
FilterFunction,
|
||||
InstructionFilter,
|
||||
KeywordFilter,
|
||||
LengthFilter,
|
||||
NonEnglishFilter,
|
||||
PunctuationFilter,
|
||||
RougeSimilarityFilter,
|
||||
)
|
||||
from .self_instruct import SelfInstructPipeline
|
||||
|
||||
__all__ = [
|
||||
'SelfInstructPipeline',
|
||||
'InstructionFilter',
|
||||
'NonEnglishFilter',
|
||||
'PunctuationFilter',
|
||||
'RougeSimilarityFilter',
|
||||
'FilterFunction',
|
||||
'KeywordFilter',
|
||||
'LengthFilter',
|
||||
'FILTER_REGISTRY',
|
||||
]
|
||||
34
camel/datagen/self_instruct/filter/__init__.py
Normal file
34
camel/datagen/self_instruct/filter/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .filter_function import (
|
||||
FilterFunction,
|
||||
KeywordFilter,
|
||||
LengthFilter,
|
||||
NonEnglishFilter,
|
||||
PunctuationFilter,
|
||||
RougeSimilarityFilter,
|
||||
)
|
||||
from .filter_registry import FILTER_REGISTRY
|
||||
from .instruction_filter import InstructionFilter
|
||||
|
||||
__all__ = [
|
||||
"LengthFilter",
|
||||
"NonEnglishFilter",
|
||||
"PunctuationFilter",
|
||||
"RougeSimilarityFilter",
|
||||
"FilterFunction",
|
||||
"KeywordFilter",
|
||||
"InstructionFilter",
|
||||
"FILTER_REGISTRY",
|
||||
]
|
||||
216
camel/datagen/self_instruct/filter/filter_function.py
Normal file
216
camel/datagen/self_instruct/filter/filter_function.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from rouge import Rouge
|
||||
|
||||
from camel.models.reward import BaseRewardModel
|
||||
|
||||
|
||||
class FilterFunction(ABC):
|
||||
r"""A base abstract class for filter functions.
|
||||
|
||||
Subclasses must implement the `apply` method, which determines whether
|
||||
a given instruction passes the filter criteria.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Evaluate the given instruction based on the filter's criteria.
|
||||
|
||||
Args:
|
||||
instruction (str): The instruction to evaluate.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction passes the filter, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LengthFilter(FilterFunction):
|
||||
r"""Filters instructions based on their word count.
|
||||
|
||||
Args:
|
||||
min_len (int): The minimum word count required for an instruction.
|
||||
(default::obj:`5`)
|
||||
max_len (int): The maximum word count allowed for an instruction.
|
||||
(default::obj:`200`)
|
||||
"""
|
||||
|
||||
def __init__(self, min_len: int = 5, max_len: int = 200):
|
||||
self.min_len = min_len
|
||||
self.max_len = max_len
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): the instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True if the length of the instruction is within the range
|
||||
of [min_len, max_len]
|
||||
"""
|
||||
word_count = len(instruction.split())
|
||||
return self.min_len <= word_count <= self.max_len
|
||||
|
||||
|
||||
class KeywordFilter(FilterFunction):
|
||||
r"""Filters instructions that contain specific undesirable keywords.
|
||||
|
||||
Args:
|
||||
keywords (List[str]): A list of keywords to filter out.
|
||||
"""
|
||||
|
||||
def __init__(self, keywords: List[str]):
|
||||
self.keywords = [keyword.lower() for keyword in keywords]
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): the instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True Instruction must NOT contain any of the keywords.
|
||||
"""
|
||||
lower_instr = instruction.lower()
|
||||
return not any(keyword in lower_instr for keyword in self.keywords)
|
||||
|
||||
|
||||
class PunctuationFilter(FilterFunction):
|
||||
r"""Filters instructions that begin with a non-alphanumeric character."""
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): the instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction does not start with punctuation.
|
||||
"""
|
||||
return not re.match(r'^[^\w\s]', instruction)
|
||||
|
||||
|
||||
class NonEnglishFilter(FilterFunction):
|
||||
r"""Filters instructions that do not begin with English letters."""
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): the instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction starts with an English letter.
|
||||
"""
|
||||
return bool(re.match(r'^[A-Za-z]', instruction))
|
||||
|
||||
|
||||
class RougeSimilarityFilter(FilterFunction):
|
||||
r"""Filters instructions that are too similar to existing instructions
|
||||
based on ROUGE scores.
|
||||
|
||||
Args:
|
||||
existing_instructions (List[str]): A list of existing instructions to
|
||||
compare against.
|
||||
threshold (float): The similarity threshold for filtering.
|
||||
(default::obj:`0.7`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, existing_instructions: List[str], threshold: float = 0.7
|
||||
):
|
||||
self.existing_instructions = existing_instructions
|
||||
self.threshold = threshold
|
||||
self.rouge = Rouge()
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): the instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction's similarity to any existing
|
||||
instruction is below the threshold.
|
||||
"""
|
||||
if not self.existing_instructions:
|
||||
return True
|
||||
|
||||
for existing_instr in self.existing_instructions:
|
||||
scores = self.rouge.get_scores(instruction, existing_instr)
|
||||
score = scores[0]['rouge-l']['f']
|
||||
if score > self.threshold:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class RewardModelFilter(FilterFunction):
|
||||
r"""Filters instructions based on scores provided by a reward model.
|
||||
|
||||
Args:
|
||||
reward_model (BaseRewardModel): The reward model used to evaluate
|
||||
the instructions.
|
||||
threshold (float): The minimum score required for an instruction
|
||||
to pass the filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reward_model: BaseRewardModel,
|
||||
threshold: float = 0.5,
|
||||
):
|
||||
self.prompt = ""
|
||||
self.reward_model = reward_model
|
||||
self.threshold = threshold
|
||||
|
||||
def apply(self, instruction: str) -> bool:
|
||||
r"""Filter the instruction
|
||||
|
||||
Args:
|
||||
instruction (str): The instruction to be filtered.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction's score is above the threshold.
|
||||
|
||||
Raises:
|
||||
ValueError: ValueError: If `score_types` is empty or if the
|
||||
required score is not found in `scores`.
|
||||
"""
|
||||
|
||||
data = [
|
||||
{"role": "user", "content": self.prompt},
|
||||
{"role": "assistant", "content": instruction},
|
||||
]
|
||||
scores = self.reward_model.evaluate(data)
|
||||
score_types = self.reward_model.get_scores_types()
|
||||
if not score_types:
|
||||
raise ValueError("No score types available from the reward model.")
|
||||
|
||||
score_type = score_types[0]
|
||||
score = scores.get(score_type, None)
|
||||
|
||||
if score is None:
|
||||
raise ValueError(
|
||||
f"Score type '{score_type}' is not found in the "
|
||||
"evaluation scores."
|
||||
)
|
||||
|
||||
return score >= self.threshold
|
||||
56
camel/datagen/self_instruct/filter/filter_registry.py
Normal file
56
camel/datagen/self_instruct/filter/filter_registry.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from .filter_function import (
|
||||
FilterFunction,
|
||||
KeywordFilter,
|
||||
LengthFilter,
|
||||
NonEnglishFilter,
|
||||
PunctuationFilter,
|
||||
RewardModelFilter,
|
||||
RougeSimilarityFilter,
|
||||
)
|
||||
|
||||
FILTER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], FilterFunction]] = {
|
||||
"length": lambda kwargs: LengthFilter(
|
||||
min_len=kwargs.get("min_len", 5), max_len=kwargs.get("max_len", 200)
|
||||
),
|
||||
"keyword": lambda kwargs: KeywordFilter(
|
||||
keywords=kwargs.get("keywords", ["image", "data"])
|
||||
),
|
||||
"punctuation": lambda kwargs: PunctuationFilter(),
|
||||
"non_english": lambda kwargs: NonEnglishFilter(),
|
||||
"rouge_similarity": lambda kwargs: RougeSimilarityFilter(
|
||||
existing_instructions=kwargs.get("existing_instructions", []),
|
||||
threshold=kwargs.get("threshold", 0.7),
|
||||
),
|
||||
"reward": lambda kwargs: RewardModelFilter(
|
||||
reward_model=kwargs.get("reward_model"), # type:ignore[arg-type]
|
||||
threshold=kwargs.get("threshold", 0.7),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def register_filter(
|
||||
name: str, constructor: Callable[[Dict[str, Any]], FilterFunction]
|
||||
):
|
||||
r"""Registers a new filter constructor in FILTER_REGISTRY.
|
||||
|
||||
Args:
|
||||
name (str): Unique name of the filter.
|
||||
constructor (Callable[[Dict[str, Any]], FilterFunction]): Function to
|
||||
create the filter using a dictionary of parameters.
|
||||
"""
|
||||
FILTER_REGISTRY[name] = constructor
|
||||
97
camel/datagen/self_instruct/filter/instruction_filter.py
Normal file
97
camel/datagen/self_instruct/filter/instruction_filter.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from camel.logger import get_logger
|
||||
|
||||
from .filter_function import FilterFunction, RewardModelFilter
|
||||
from .filter_registry import FILTER_REGISTRY
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InstructionFilter:
|
||||
def __init__(
|
||||
self,
|
||||
filters_config: Dict[str, Dict[str, Any]],
|
||||
stop_on_first_failure: bool = False,
|
||||
):
|
||||
r"""Initialize the InstructionFilter with a dictionary of filter
|
||||
configurations.
|
||||
|
||||
Args:
|
||||
filters_config(Dict[str, Dict[str, Any]]):
|
||||
Example filters_config:
|
||||
{
|
||||
"length": {"min_len": 5, "max_len": 100},
|
||||
"keyword": {"keywords": ["image", "video"]},
|
||||
"non_english": {},
|
||||
"rouge_similarity": {
|
||||
"existing_instructions": ["Some existing text"],
|
||||
"threshold": 0.6
|
||||
}
|
||||
}
|
||||
Each key in filters_config corresponds to a filter name
|
||||
(registered in FILTER_REGISTRY).
|
||||
Each value is a dict of parameters for that filter.
|
||||
stop_on_first_failure (bool): If True, stops checking filters after
|
||||
the first failure.
|
||||
"""
|
||||
self.filters: List[FilterFunction] = []
|
||||
for filter_name, params in filters_config.items():
|
||||
if filter_name not in FILTER_REGISTRY:
|
||||
raise ValueError(f"Unknown filter function: {filter_name}")
|
||||
self.filters.append(FILTER_REGISTRY[filter_name](params))
|
||||
self.stop_on_first_failure: bool = stop_on_first_failure
|
||||
|
||||
def add_filter(self, filter_function: FilterFunction):
|
||||
r"""Add a custom filter function to the InstructionFilter.
|
||||
This allows adding filters that are not in the registry.
|
||||
|
||||
Args:
|
||||
filter_function (FilterFunction): The filter function to be added
|
||||
"""
|
||||
self.filters.append(filter_function)
|
||||
|
||||
def filter(
|
||||
self, prompt: str, instruction: str, return_details: bool = False
|
||||
) -> Union[bool, Tuple[bool, List[str]]]:
|
||||
r"""Check if the given instruction passes all filter functions.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt of generating the instruction.
|
||||
instruction (str): The instruction to evaluate.
|
||||
return_details (bool): If True, returns a tuple (bool, List[str])
|
||||
where the list contains the names of filters that failed.
|
||||
(default::obj:`False`)
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction passes all filters, False otherwise.
|
||||
OR (bool, List[str]) if return_details is True.
|
||||
"""
|
||||
failed_filters = []
|
||||
for f in self.filters:
|
||||
if isinstance(f, RewardModelFilter):
|
||||
f.prompt = prompt
|
||||
if not f.apply(instruction):
|
||||
failed_filters.append(type(f).__name__)
|
||||
logger.warning(
|
||||
f"{type(f).__name__} failed instruction: {instruction}"
|
||||
)
|
||||
if self.stop_on_first_failure:
|
||||
break
|
||||
|
||||
if return_details:
|
||||
return len(failed_filters) == 0, failed_filters
|
||||
return len(failed_filters) == 0
|
||||
445
camel/datagen/self_instruct/self_instruct.py
Normal file
445
camel/datagen/self_instruct/self_instruct.py
Normal file
@@ -0,0 +1,445 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
|
||||
from .filter import RougeSimilarityFilter
|
||||
from .filter.instruction_filter import InstructionFilter
|
||||
from .templates import SelfInstructTemplates
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SelfInstructPipeline:
|
||||
r"""A pipeline to generate and manage machine-generated instructions for
|
||||
tasks, combining human and machine task samples.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent used to interact and generate
|
||||
instructions.
|
||||
seed (str): The path to the human-written instructions.
|
||||
num_machine_instructions (int): Number of machine-generated
|
||||
instructions to generate. (default::obj:`5`)
|
||||
data_output_path (Optional[str]): Path to save the generated data.
|
||||
(default::obj:`./data_output.json`)
|
||||
human_to_machine_ratio (tuple): Ratio of human to machine tasks used
|
||||
for instruction generation. (default::obj:`(6, 2)`)
|
||||
instruction_filter (InstructionFilter): A filter to validate
|
||||
generated instructions. (default::obj:`None`)
|
||||
filter_config (Optional[Dict[str, Dict[str, Any]]]): configuration
|
||||
for the filter functions registered in FILE_REGISTRY.
|
||||
(default::obj:`None`)
|
||||
stop_on_first_failure (bool): If True, stops checking filters after
|
||||
the first failure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
seed: str,
|
||||
num_machine_instructions: int = 5,
|
||||
data_output_path: Optional[str] = './data_output.json',
|
||||
human_to_machine_ratio: tuple = (6, 2),
|
||||
instruction_filter: Optional[InstructionFilter] = None,
|
||||
filter_config: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
stop_on_first_failure: bool = False,
|
||||
):
|
||||
self.agent = agent
|
||||
self.num_machine_instructions = num_machine_instructions
|
||||
self.data_output_path = data_output_path
|
||||
self.human_to_machine_ratio = human_to_machine_ratio
|
||||
self.human_tasks: List[Dict] = []
|
||||
self.machine_tasks: List[Dict] = []
|
||||
self.load_seed(seed)
|
||||
default_config: Dict[str, Dict[str, Any]] = {
|
||||
"length": {},
|
||||
"keyword": {},
|
||||
"punctuation": {},
|
||||
"non_english": {},
|
||||
"rouge_similarity": {},
|
||||
}
|
||||
|
||||
if instruction_filter is not None:
|
||||
# custom
|
||||
self.instruction_filter = instruction_filter
|
||||
else:
|
||||
# default
|
||||
config_to_use = (
|
||||
filter_config if filter_config is not None else default_config
|
||||
)
|
||||
self.instruction_filter = InstructionFilter(
|
||||
config_to_use, stop_on_first_failure
|
||||
)
|
||||
|
||||
def load_seed(self, path: str):
|
||||
r"""Load seed tasks from a file. Defaults to a predefined seed file if
|
||||
no path is provided.
|
||||
|
||||
Args:
|
||||
path (str): Path to the seed file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the seed file does not exist.
|
||||
"""
|
||||
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self.human_tasks.append(json.loads(line))
|
||||
else:
|
||||
raise FileNotFoundError(f"Seed file not found at path: {path}")
|
||||
|
||||
def sample_human_tasks(self, count: int) -> List[dict]:
|
||||
r"""Sample a specified number of human tasks from the loaded seed.
|
||||
|
||||
Args:
|
||||
count (int): Number of human tasks to sample.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of sampled human tasks.
|
||||
"""
|
||||
return random.sample(
|
||||
self.human_tasks, min(count, len(self.human_tasks))
|
||||
)
|
||||
|
||||
def sample_machine_tasks(self, count: int) -> List[dict]:
|
||||
r"""Sample a specified number of machine tasks.
|
||||
|
||||
Args:
|
||||
count (int): Number of machine tasks to sample.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of sampled machine tasks, with placeholders if
|
||||
insufficient tasks are available.
|
||||
"""
|
||||
available_machine_tasks = len(self.machine_tasks)
|
||||
if available_machine_tasks < count:
|
||||
sampled_tasks = self.machine_tasks.copy()
|
||||
placeholders_needed = count - available_machine_tasks
|
||||
sampled_tasks.extend(
|
||||
[{'instruction': ""} for _ in range(placeholders_needed)]
|
||||
)
|
||||
return sampled_tasks
|
||||
|
||||
return random.sample(self.machine_tasks, count)
|
||||
|
||||
def generate_machine_instruction(self) -> List:
|
||||
r"""Generate a machine instruction using the agent.
|
||||
|
||||
Combines human and machine tasks based on the configured ratio to
|
||||
create a prompt for instruction generation.
|
||||
|
||||
Returns:
|
||||
List: The prompt and a machine-generated instruction.
|
||||
"""
|
||||
|
||||
sampled_human_tasks = self.sample_human_tasks(
|
||||
self.human_to_machine_ratio[0]
|
||||
)
|
||||
sampled_machine_tasks = self.sample_machine_tasks(
|
||||
self.human_to_machine_ratio[1]
|
||||
)
|
||||
prompt = "Below are some tasks:\n\n"
|
||||
|
||||
for idx, task in enumerate(sampled_human_tasks, 1):
|
||||
prompt += f"Task {idx}: {task['instruction']}\n"
|
||||
|
||||
current_task_number = len(sampled_human_tasks) + 1
|
||||
for idx, task in enumerate(sampled_machine_tasks, current_task_number):
|
||||
prompt += f"Task {idx}: {task['instruction']}\n"
|
||||
|
||||
task_num = len(sampled_human_tasks) + len(sampled_machine_tasks) + 1
|
||||
prompt += f"Task {task_num}:"
|
||||
prompt += (
|
||||
"\nNow, please produce exactly one new task that fits the "
|
||||
"style of the ones above.\n Do not include any task numbering or "
|
||||
"labels like 'Task X:'. Just write the task itself.\n"
|
||||
"The task should be a single sentence.\n\n"
|
||||
)
|
||||
|
||||
response = self.agent.step(prompt)
|
||||
self.agent.reset()
|
||||
generated_tasks = [
|
||||
line.strip()
|
||||
for line in response.msgs[0].content.split("\n")
|
||||
if line.strip()
|
||||
]
|
||||
return [prompt, generated_tasks[0]]
|
||||
|
||||
def identify_instruction(self, instruction: str) -> bool:
|
||||
r"""Determine if the given instruction is a classification task.
|
||||
|
||||
Args:
|
||||
instruction (str): The instruction to classify.
|
||||
|
||||
Returns:
|
||||
bool: True if the instruction is a classification task,
|
||||
otherwise False.
|
||||
"""
|
||||
clf_prompt = (
|
||||
SelfInstructTemplates.clf_template
|
||||
+ f"Task: {instruction}\nIs it classification?"
|
||||
+ "\nRespond in the following structured format:"
|
||||
"\n{\n \"answer\": true\n}\n"
|
||||
"or\n"
|
||||
"{\n \"answer\": false\n}\n"
|
||||
)
|
||||
response = self.agent.step(clf_prompt)
|
||||
self.agent.reset()
|
||||
try:
|
||||
structured_response = AgentResponse.parse_raw(
|
||||
response.msgs[0].content.strip()
|
||||
)
|
||||
return structured_response.answer
|
||||
except ValueError as e:
|
||||
logger.error(f"Error parsing agent response: {e}")
|
||||
return False
|
||||
|
||||
def generate_machine_instances(self):
|
||||
r"""Generate instances for each machine task based on its
|
||||
classification status.
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting output generation: target {len(self.machine_tasks)} "
|
||||
f"instructions"
|
||||
)
|
||||
attempt_count = 0
|
||||
for instruction in self.machine_tasks:
|
||||
instance = self.generate_machine_instance(
|
||||
instruction['instruction'], instruction['is_classification']
|
||||
)
|
||||
instruction['instances'] = instance
|
||||
attempt_count += 1
|
||||
logger.info(
|
||||
f"Attempt[Output]: Progress {attempt_count}/"
|
||||
f"{len(self.machine_tasks)} instructions"
|
||||
)
|
||||
|
||||
def generate_machine_instance(
|
||||
self, instruction: str, classification: bool
|
||||
) -> list[dict]:
|
||||
r"""Generate instances for a given instruction.
|
||||
|
||||
Args:
|
||||
instruction (str): The instruction to create instances for.
|
||||
classification (bool): Whether the instruction is a classification
|
||||
task.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of generated instances in input-output format.
|
||||
"""
|
||||
if classification:
|
||||
prompt = (
|
||||
SelfInstructTemplates.output_first_template_for_clf.format(
|
||||
instruction=instruction
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt = SelfInstructTemplates.input_first_template_for_gen.format(
|
||||
instruction=instruction
|
||||
)
|
||||
|
||||
response = self.agent.step(prompt)
|
||||
self.agent.reset()
|
||||
generated_text = response.msgs[0].content.strip()
|
||||
|
||||
if classification:
|
||||
return self.parse_classification_output(generated_text)
|
||||
else:
|
||||
return self.parse_non_classification_output(generated_text)
|
||||
|
||||
def parse_classification_output(
|
||||
self, generated_text: str
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""Parse the generated text for classification tasks into input-output
|
||||
pairs.
|
||||
|
||||
Args:
|
||||
generated_text (str): The raw text generated by the agent for
|
||||
classification tasks.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries with 'input' and
|
||||
'output' keys.
|
||||
"""
|
||||
instances = []
|
||||
lines = generated_text.split("\n")
|
||||
current_label = None
|
||||
current_input = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("Class label:"):
|
||||
if current_label and current_input:
|
||||
instances.append(
|
||||
{
|
||||
"input": current_input.strip(),
|
||||
"output": current_label.strip(),
|
||||
}
|
||||
)
|
||||
|
||||
current_label = line[len("Class label:") :].strip()
|
||||
current_input = None
|
||||
else:
|
||||
if current_input is None:
|
||||
current_input = line
|
||||
else:
|
||||
current_input += f"\n{line}"
|
||||
if current_label and current_input:
|
||||
instances.append(
|
||||
{
|
||||
"input": current_input.strip(),
|
||||
"output": current_label.strip(),
|
||||
}
|
||||
)
|
||||
|
||||
return instances
|
||||
|
||||
def parse_non_classification_output(
|
||||
self, generated_text: str
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""Parse the generated text for non-classification tasks into
|
||||
input-output pairs.
|
||||
|
||||
Args:
|
||||
generated_text (str): The raw text generated by the agent for
|
||||
non-classification tasks.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries with 'input' and
|
||||
'output' keys.
|
||||
"""
|
||||
instances = []
|
||||
prev = 0
|
||||
lines = generated_text.split("\n")
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
if line.startswith("Example "):
|
||||
prev = i + 1
|
||||
|
||||
elif line.startswith("Output:"):
|
||||
instance_input = '\n'.join(lines[prev:i]).strip()
|
||||
if instance_input.startswith("Input: "):
|
||||
instance_input = instance_input[len("Input: ") :].strip()
|
||||
else:
|
||||
instance_input = instance_input.strip()
|
||||
|
||||
instance_output = line[len("Output:") :].strip()
|
||||
i += 1
|
||||
while i < len(lines) and not lines[i].strip().startswith(
|
||||
"Example "
|
||||
):
|
||||
instance_output += '\n' + lines[i].strip()
|
||||
i += 1
|
||||
i -= 1
|
||||
|
||||
instance_output = instance_output.strip()
|
||||
|
||||
instances.append(
|
||||
{"input": instance_input, "output": instance_output}
|
||||
)
|
||||
|
||||
prev = i + 1
|
||||
i += 1
|
||||
|
||||
if not instances:
|
||||
instances.append({"input": "", "output": "No valid output found."})
|
||||
|
||||
return instances
|
||||
|
||||
def construct_data(self):
|
||||
r"""Save the machine-generated tasks to the specified output path
|
||||
in JSON format.
|
||||
"""
|
||||
with open(self.data_output_path, 'w') as f:
|
||||
json.dump(self.machine_tasks, f, indent=4, ensure_ascii=False)
|
||||
|
||||
def generate(self, timeout_minutes=600):
|
||||
r"""Execute the entire pipeline to generate machine instructions
|
||||
and instances.
|
||||
|
||||
Args:
|
||||
timeout_minutes (int): Maximum time in minutes to run the
|
||||
generation process before timing out. (default: :obj:`600`)
|
||||
"""
|
||||
start_time = time.time()
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
logger.info(
|
||||
f"Starting instruction generation: target "
|
||||
f"{self.num_machine_instructions} instructions"
|
||||
)
|
||||
while len(self.machine_tasks) < self.num_machine_instructions:
|
||||
# Check for timeout
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.info(
|
||||
f"Generation timed out after {elapsed / 60:.1f} minutes. "
|
||||
f"Generated {len(self.machine_tasks)}/"
|
||||
f"{self.num_machine_instructions} instructions."
|
||||
)
|
||||
break
|
||||
prompt, instruction = self.generate_machine_instruction()
|
||||
existing_instructions = [
|
||||
t["instruction"] for t in self.human_tasks
|
||||
] + [t["instruction"] for t in self.machine_tasks]
|
||||
for f in self.instruction_filter.filters:
|
||||
if isinstance(f, RougeSimilarityFilter):
|
||||
f.existing_instructions = existing_instructions
|
||||
if self.instruction_filter.filter(prompt, instruction):
|
||||
instruction_dict = {
|
||||
"id": f"machine_task_{len(self.machine_tasks) + 1}",
|
||||
"instruction": instruction,
|
||||
"is_classification": self.identify_instruction(
|
||||
instruction
|
||||
),
|
||||
}
|
||||
self.machine_tasks.append(instruction_dict)
|
||||
logger.info(
|
||||
f"Attempt[Instruction]: Progress "
|
||||
f"{len(self.machine_tasks)}/"
|
||||
f"{self.num_machine_instructions} "
|
||||
f"instructions"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Instruction failed filters. Skipping instruction: "
|
||||
f"{instruction}"
|
||||
)
|
||||
self.generate_machine_instances()
|
||||
self.construct_data()
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
answer: bool = Field(
|
||||
...,
|
||||
description="Indicates whether the task is "
|
||||
"classification (True/False).",
|
||||
)
|
||||
382
camel/datagen/self_instruct/templates.py
Normal file
382
camel/datagen/self_instruct/templates.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# flake8: noqa
|
||||
@dataclass(frozen=True)
|
||||
class SelfInstructTemplates:
|
||||
r"""Contains templates prompts for self-instruct data generation"""
|
||||
|
||||
clf_template = """ '''Can the following task be regarded as a classification task with finite output labels?
|
||||
|
||||
Task: Given my personality and the job, tell me if I would be suitable.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Give me an example of a time when you had to use your sense of humor.
|
||||
Is it classification? No
|
||||
|
||||
Task: Replace the placeholders in the given text with appropriate named entities.
|
||||
Is it classification? No
|
||||
|
||||
Task: Fact checking - tell me if the statement is true, false, or unknown, based on your knowledge and common sense.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Return the SSN number for the person.
|
||||
Is it classification? No
|
||||
|
||||
Task: Detect if the Reddit thread contains hate speech.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Analyze the sentences below to identify biases.
|
||||
Is it classification? No
|
||||
|
||||
Task: Select the longest sentence in terms of the number of words in the paragraph, output the sentence index.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Find out the toxic word or phrase in the sentence.
|
||||
Is it classification? No
|
||||
|
||||
Task: Rank these countries by their population.
|
||||
Is it classification? No
|
||||
|
||||
Task: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include: Music, Sports, Politics, Tech, Finance, Basketball, Soccer, Tennis, Entertainment, Digital Game, World News. Output its categories one by one, seperated by comma.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Given the name of an exercise, explain how to do it.
|
||||
Is it classification? No
|
||||
|
||||
Task: Select the oldest person from the list.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Find the four smallest perfect numbers.
|
||||
Is it classification? No
|
||||
|
||||
Task: Does the information in the document supports the claim? You can answer "Support" or "Unsupport".
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Create a detailed budget for the given hypothetical trip.
|
||||
Is it classification? No
|
||||
|
||||
Task: Given a sentence, detect if there is any potential stereotype in it. If so, you should explain the stereotype. Else, output no.
|
||||
Is it classification? No
|
||||
|
||||
Task: Explain the following idiom to me, and try to give me some examples.
|
||||
Is it classification? No
|
||||
|
||||
Task: Is there anything I can eat for a breakfast that doesn't include eggs, yet includes protein, and has roughly 700-1000 calories?
|
||||
Is it classification? No
|
||||
|
||||
Task: Answer the following multiple choice question. Select A, B, C, or D for the final answer.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Decide whether the syllogism is logically sound.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: How can individuals and organizations reduce unconscious bias?
|
||||
Is it classification? No
|
||||
|
||||
Task: What are some things you can do to de-stress?
|
||||
Is it classification? No
|
||||
|
||||
Task: Find out the largest one from a set of numbers. Output the number directly.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Replace the <mask> token in the text with proper words that are consistent with the context. You can use multiple words for each <mask> token.
|
||||
Is it classification? No
|
||||
|
||||
Task: Write a cover letter based on the given facts.
|
||||
Is it classification? No
|
||||
|
||||
Task: Identify the pos tag of the word in the given sentence.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: Write a program to compute the sum of integers from k to n.
|
||||
Is it classification? No
|
||||
|
||||
Task: In this task, you need to compare the meaning of the two sentences and tell if they are the same. Output yes or no.
|
||||
Is it classification? Yes
|
||||
|
||||
Task: To make the pairs have the same analogy, write the fourth word.
|
||||
Is it classification? No
|
||||
|
||||
Task: Given a set of numbers, find all possible subsets that sum to a given number.
|
||||
Is it classification? No
|
||||
|
||||
"""
|
||||
output_first_template_for_clf = '''You are given a classification instruction.
|
||||
|
||||
Produce multiple labeled examples following the format below. For each example:
|
||||
- Begin with a "Class label:" line identifying one possible category.
|
||||
- Follow that with one line specifying the example input (e.g., "Sentence:", "Dialogue:", "Opinion:", or "Email:").
|
||||
- The content after these lines should serve as an illustrative example of that label.
|
||||
|
||||
Do not restate or include the "Task:" line. Do not add additional commentary. Just produce the labeled examples.
|
||||
|
||||
Example format (no initial task line, task will be provided) when task is Task: Classify the sentiment of the sentence into positive, negative, or mixed.:
|
||||
Class label: mixed
|
||||
Sentence: I enjoy the flavor of the restaurant but their service is too slow.
|
||||
Class label: Positive
|
||||
Sentence: I had a great day today. The weather was beautiful and I spent time with friends and family.
|
||||
Class label: Negative
|
||||
Sentence: I was really disappointed by the latest superhero movie. I would not recommend it to anyone.
|
||||
|
||||
Below are more examples:
|
||||
|
||||
Task: Given a dialogue, classify whether the user is satisfied with the service. You should respond with "Satisfied" or "Unsatisfied".
|
||||
Class label: Satisfied
|
||||
Dialogue:
|
||||
- Agent: Thank you for your feedback. We will work to improve our service in the future.
|
||||
- Customer: I am happy with the service you provided. Thank you for your help.
|
||||
Class label: Unsatisfied
|
||||
Dialogue:
|
||||
- Agent: I am sorry we will cancel that order for you, and you will get a refund within 7 business days.
|
||||
- Customer: oh that takes too long. I want you to take quicker action on this.
|
||||
|
||||
Task: Given some political opinions, classify whether the person belongs to Democrats or Republicans.
|
||||
Class label: Democrats
|
||||
Opinion: I believe that everyone should have access to quality healthcare regardless of their income level.
|
||||
Class label: Republicans
|
||||
Opinion: I believe that people should be able to keep more of their hard-earned money and should not be taxed at high rates.
|
||||
|
||||
Task: Tell me if the following email is a promotion email or not.
|
||||
Class label: Promotion
|
||||
Email: Check out our amazing new sale! We've got discounts on all of your favorite products.
|
||||
Class label: Not Promotion
|
||||
Email: We hope you are doing well. Let us know if you need any help.
|
||||
|
||||
Task: Detect if the Reddit thread contains hate speech.
|
||||
Class label: Hate Speech
|
||||
Thread: All people of color are stupid and should not be allowed to vote.
|
||||
Class label: Not Hate Speech
|
||||
Thread: The best way to cook a steak on the grill.
|
||||
|
||||
Task: Does the information in the document supports the claim? You can answer "Support" or "Unsupport".
|
||||
Class label: Unsupport
|
||||
Document: After a record-breaking run that saw mortgage rates plunge to all-time lows and home prices soar to new highs, the U.S. housing market finally is slowing. While demand and price gains are cooling, any correction is likely to be a modest one, housing economists and analysts say. No one expects price drops on the scale of the declines experienced during the Great Recession.
|
||||
Claim: The US housing market is going to crash soon.
|
||||
Class label: Support
|
||||
Document: The U.S. housing market is showing signs of strain, with home sales and prices slowing in many areas. Mortgage rates have risen sharply in recent months, and the number of homes for sale is increasing. This could be the beginning of a larger downturn, with some economists predicting a potential housing crash in the near future.
|
||||
Claim: The US housing market is going to crash soon.
|
||||
|
||||
Task: Answer the following multiple-choice question. Select A, B, C, or D for the final answer.
|
||||
Class label: C
|
||||
Question: What is the capital of Germany?
|
||||
A. London
|
||||
B. Paris
|
||||
C. Berlin
|
||||
D. Rome
|
||||
Class label: D
|
||||
Question: What is the largest planet in our solar system?
|
||||
A) Earth
|
||||
B) Saturn
|
||||
C) Mars
|
||||
D) Jupiter
|
||||
Class label: A
|
||||
Question: What is the process by which plants make their own food through photosynthesis?
|
||||
A) Respiration
|
||||
B) Fermentation
|
||||
C) Digestion
|
||||
D) Metabolism
|
||||
Class label: B
|
||||
Question: Who wrote the novel "The Great Gatsby"?
|
||||
A) Ernest Hemingway
|
||||
B) F. Scott Fitzgerald
|
||||
C) J.D. Salinger
|
||||
D) Mark Twain
|
||||
|
||||
Task: You need to read a code and detect if there is a syntax error or not. Output true if there is an error, output false if there is not.
|
||||
Class label: true
|
||||
Code:
|
||||
def quick_sort(arr):
|
||||
if len(arr) < 2
|
||||
return arr
|
||||
Class label: False
|
||||
Code:
|
||||
def calculate_average(numbers):
|
||||
total = 0
|
||||
for number in numbers:
|
||||
total += number
|
||||
return total / len(numbers)
|
||||
|
||||
Task: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include Sports and Politics. Output its categories one by one, separated by a comma.
|
||||
Class label: Sports
|
||||
Article: The Golden State Warriors have won the NBA championship for the second year in a row.
|
||||
Class label: Politics
|
||||
Article: The United States has withdrawn from the Paris Climate Agreement.
|
||||
Class label: Politics, Sports
|
||||
Article: The government has proposed cutting funding for youth sports programs.
|
||||
|
||||
Task: Given a credit card statement, the cardholder's spending habits, and the account balance, classify whether the cardholder is at risk of defaulting on their payments or not.
|
||||
Class label: At risk
|
||||
Credit card statement: Purchases at high-end clothing stores and luxury hotels.
|
||||
Cardholder's spending habits: Frequent purchases at luxury brands and high-end establishments.
|
||||
Account balance: Over the credit limit and multiple missed payments.
|
||||
Class label: Not at risk
|
||||
Credit card statement: Purchases at grocery stores and gas stations.
|
||||
Cardholder's spending habits: Regular purchases for necessary expenses and occasional dining out.
|
||||
Account balance: Slightly below the credit limit and no missed payments.
|
||||
|
||||
Task: Given a social media post, the hashtags used, and a topic. classify whether the post is relevant to the topic or not.
|
||||
Class label: Relevant
|
||||
Post: I can't believe the government is still not taking action on climate change. It's time for us to take matters into our own hands.
|
||||
Hashtags: #climatechange #actnow
|
||||
Topic: Climate change
|
||||
Class label: Not relevant
|
||||
Post: I just bought the new iPhone and it is amazing!
|
||||
Hashtags: #apple #technology
|
||||
Topic: Travel
|
||||
|
||||
Task: The answer will be 'yes' if the provided sentence contains an explicit mention that answers the given question. Otherwise, answer 'no'.
|
||||
Class label: Yes
|
||||
Sentence: Jack played basketball for an hour after school.
|
||||
Question: How long did Jack play basketball?
|
||||
Class label: No
|
||||
Sentence: The leaders of the Department of Homeland Security now appear before 88 committees and subcommittees of Congress.
|
||||
Question: How often are they required to appear?
|
||||
|
||||
Task: Tell me what's the second largest city by population in Canada.
|
||||
Class label: Montreal
|
||||
|
||||
Task: Classifying different types of mathematical equations, such as linear, and quadratic equations, based on the coefficients and terms in the equation.
|
||||
Class label: Linear equation
|
||||
Equation: y = 2x + 5
|
||||
Class label: Quadratic equation
|
||||
Equation: y = x^2 - 4x + 3
|
||||
|
||||
Task: Tell me the first number of the given list.
|
||||
Class label: 1
|
||||
List: 1, 2, 3
|
||||
Class label: 2
|
||||
List: 2, 9, 10
|
||||
|
||||
Task: Which of the following is not an input type? (a) number (b) date (c) phone number (d) email address (e) all of these are valid inputs.
|
||||
Class label: (e)
|
||||
|
||||
Now, using the given instruction, produce several formatted examples accordingly:
|
||||
Task: {instruction}
|
||||
'''
|
||||
|
||||
input_first_template_for_gen = '''You will be given a task,
|
||||
Your job is to generate at most two example instances demonstrating how to
|
||||
perform this task. For each instance:
|
||||
- If the task requires input (as an actual example of the task), provide it.
|
||||
- If the task can be answered directly without requiring input, omit the input section.
|
||||
|
||||
Example 1
|
||||
Input: [Provide input here if needed, otherwise omit this section]
|
||||
Output: [Provide the correct output]
|
||||
|
||||
Example 2
|
||||
Input: [Provide input here if needed, otherwise omit this section]
|
||||
Output: [Provide the correct output]
|
||||
|
||||
Do not include any additional commentary, explanations, or more than two instances.
|
||||
|
||||
Below are some examples:
|
||||
|
||||
Task: Which exercises are best for reducing belly fat at home?
|
||||
Output:
|
||||
- Lying Leg Raises
|
||||
- Leg In And Out
|
||||
- Plank
|
||||
- Side Plank
|
||||
- Sit-ups
|
||||
|
||||
Task: Extract all the country names in the paragraph, list them separated by commas.
|
||||
Example 1
|
||||
Paragraph: Dr. No is the sixth novel by the English author Ian Fleming to feature his British Secret Service agent James Bond. Written at Fleming's Goldeneye estate in Jamaica, it was first published in the United Kingdom by Jonathan Cape in 1958. In the novel Bond looks into the disappearance in Jamaica of two fellow MI6 operatives who had been investigating Doctor No. Bond travels to No's Caribbean island and meets Honeychile Rider, who is there to collect shells. They are captured and taken to a luxurious facility carved into a mountain. The character of Doctor No, the son of a German missionary and a Chinese woman, was influenced by Sax Rohmer's Fu Manchu stories. Dr. No was the first of Fleming's novels to face widespread negative reviews in Britain, but it was received more favourably in the United States.
|
||||
Output: English, British, Jamaica, the United Kingdom, German, Chinese, Britain, the United States.
|
||||
|
||||
Task: Converting 85 F to Celsius.
|
||||
Output: 85°F = 29.44°C
|
||||
|
||||
Task: Sort the given list ascendingly.
|
||||
Example 1
|
||||
List: [10, 92, 2, 5, -4, 92, 5, 101]
|
||||
Output: [-4, 2, 5, 5, 10, 92, 92, 101]
|
||||
Example 2
|
||||
Input 2 - List: [9.99, 10, -5, -1000, 5e6, 999]
|
||||
Output: [-1000, -5, 9.99, 10, 999, 5e6]
|
||||
|
||||
Task: Suggest a better and more professional rephrasing of the following sentence.
|
||||
Example 1
|
||||
Sentence: This house is surprisingly not constructed very well, and you probably need more money to fix it after you buy it. If you ask me, I would suggest you to consider other candidates.
|
||||
Output: This house does not seem to be constructed well, so you may need to spend more money to fix it after you purchase it. I would suggest that you look at other properties.
|
||||
Example 2
|
||||
Sentence: Just so you know, we did an experiment last week and found really surprising results - language model can improve itself!
|
||||
Output: Our experiments last week demonstrated surprising results, proving that the language model can improve itself.
|
||||
|
||||
Task: Read the following paragraph and answer a math question about the paragraph. You need to write out the calculation for getting the final answer.
|
||||
Example 1
|
||||
Paragraph: Gun violence in the United States results in tens of thousands of deaths and injuries annually, and was the leading cause of death for children 19 and younger in 2020. In 2018, the most recent year for which data are available as of 2021, the Centers for Disease Control and Prevention's (CDC) National Center for Health Statistics reports 38,390 deaths by firearm, of which 24,432 were by suicide. The rate of firearm deaths per 100,000 people rose from 10.3 per 100,000 in 1999 to 12 per 100,000 in 2017, with 109 people dying per day or about 14,542 homicides in total, being 11.9 per 100,000 in 2018. In 2010, there were 19,392 firearm-related suicides, and 11,078 firearm-related homicides in the U.S. In 2010, 358 murders were reported involving a rifle while 6,009 were reported involving a handgun; another 1,939 were reported with an unspecified type of firearm. In 2011, a total of 478,400 fatal and nonfatal violent crimes were committed with a firearm.
|
||||
Question: How many more firearm-related deaths were there in 2018 compared to 2010?
|
||||
Output:
|
||||
38390 - (19392 + 11078) = 38390 - 30470 = 7920.
|
||||
So, in 2018, there were 7920 more deaths by firearm than in 2010.
|
||||
|
||||
Task: Write Python code to solve this leetcode problem.
|
||||
Example 1
|
||||
Problem: You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order, and each of their nodes contains a single digit. Add the two numbers and return the sum as a linked list. You may assume the two numbers do not contain any leading zero, except the number 0 itself.
|
||||
Output:
|
||||
class Solution(object):
|
||||
def addTwoNumbers(self, l1, l2):
|
||||
carry = 0
|
||||
root = n = ListNode(0)
|
||||
while l1 or l2 or carry:
|
||||
v1 = v2 = 0
|
||||
if l1:
|
||||
v1 = l1.val
|
||||
l1 = l1.next
|
||||
if l2:
|
||||
v2 = l2.val
|
||||
l2 = l2.next
|
||||
carry, val = divmod(v1+v2+carry, 10)
|
||||
n.next = ListNode(val)
|
||||
n = n.next
|
||||
return root.next
|
||||
|
||||
Task: Solve the equation and find the value of X. Show your steps.
|
||||
Example 1
|
||||
Equation: 10X + 5 = 10
|
||||
Output: 10X = 5, X = 0.5
|
||||
Example 2
|
||||
Equation: X + Y + 120 = 100
|
||||
Output: X + Y = -20, X = -20 - Y
|
||||
|
||||
Task: Write a program to compute the sum of integers from k to n.
|
||||
Output:
|
||||
def sum(k, n):
|
||||
sum = 0
|
||||
for i in range(k, n+1):
|
||||
sum += i
|
||||
return sum
|
||||
|
||||
Task: Select the oldest person from the given list.
|
||||
Example 1
|
||||
List: George Washington, Confucius, Michael Jordan, Michelangelo
|
||||
Output: Confucious
|
||||
Example 2
|
||||
List: Alan Turing, Geoffrey Hinton, Yann LeCun, Yoshua Bengio
|
||||
Output: Alan Turing
|
||||
|
||||
Task: Turn down a job offer by sending an email to a recruiter explaining the reason.
|
||||
Output: Hi [Recruiter],
|
||||
Thank you so much for the generous offer to join your team. As we discussed, I’ve admired the company for a number of years, and am a proud endorser of its products. However, after further consideration of where I currently am in my career, I’ve decided to accept an offer at another company.
|
||||
I would love to stay in touch with you and have already started following you on [Social Media Platform]. Again, thank you so much for your time and consideration.
|
||||
Thanks again,
|
||||
[Your Name]
|
||||
|
||||
Task: {instruction}
|
||||
'''
|
||||
31
camel/datagen/source2synth/__init__.py
Normal file
31
camel/datagen/source2synth/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .data_processor import (
|
||||
DataCurator,
|
||||
ExampleConstructor,
|
||||
UserDataProcessor,
|
||||
)
|
||||
from .models import MultiHopQA, ReasoningStep
|
||||
from .user_data_processor_config import (
|
||||
ProcessorConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DataCurator",
|
||||
"ExampleConstructor",
|
||||
"ProcessorConfig",
|
||||
"UserDataProcessor",
|
||||
"ReasoningStep",
|
||||
"MultiHopQA",
|
||||
]
|
||||
538
camel/datagen/source2synth/data_processor.py
Normal file
538
camel/datagen/source2synth/data_processor.py
Normal file
@@ -0,0 +1,538 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent
|
||||
from camel.datagen.source2synth.user_data_processor_config import (
|
||||
ProcessorConfig,
|
||||
)
|
||||
from camel.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserDataProcessor:
|
||||
r"""A processor for generating multi-hop question-answer pairs from user
|
||||
data.
|
||||
|
||||
This class handles the processing of text data to generate multi-hop
|
||||
question-answer pairs using either an AI model or rule-based approaches.
|
||||
It manages the entire pipeline from text preprocessing to dataset curation.
|
||||
|
||||
Attributes:
|
||||
config (ProcessorConfig): Configuration for data processing parameters.
|
||||
rng (random.Random): Random number generator for reproducibility.
|
||||
multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for
|
||||
generating QA pairs.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ProcessorConfig] = None):
|
||||
r"""Initialize the UserDataProcessor.
|
||||
|
||||
Args:
|
||||
config (Optional[ProcessorConfig], optional): Configuration for
|
||||
data processing. (default: :obj:`None`)
|
||||
"""
|
||||
self.config = config or ProcessorConfig()
|
||||
self.rng = random.Random(self.config.seed)
|
||||
self.multi_hop_agent = (
|
||||
self.config.hop_generating_agent
|
||||
if self.config.use_ai_model
|
||||
else None
|
||||
)
|
||||
|
||||
def process_text(
|
||||
self, text: str, source: str = "user_input"
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Process a single text to generate multi-hop QA pairs.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
source (str, optional): Source identifier for the text.
|
||||
(default: :obj:`"user_input"`)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of processed examples with QA pairs and
|
||||
metadata.
|
||||
"""
|
||||
# Convert text to standard format
|
||||
raw_data = [
|
||||
{
|
||||
'text': text,
|
||||
'source': source,
|
||||
}
|
||||
]
|
||||
|
||||
# Construct examples
|
||||
constructor = ExampleConstructor(self.config, self.multi_hop_agent)
|
||||
examples = constructor.construct_examples(raw_data)
|
||||
|
||||
# Manage data
|
||||
curator = DataCurator(self.config, self.rng)
|
||||
final_dataset = curator.curate_dataset(examples)
|
||||
|
||||
return final_dataset
|
||||
|
||||
def process_batch(
|
||||
self, texts: List[str], sources: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Process multiple texts in batch to generate multi-hop QA pairs.
|
||||
|
||||
Args:
|
||||
texts (List[str]): List of input texts to process.
|
||||
sources (Optional[List[str]], optional): List of source
|
||||
identifiers. (default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of processed examples with QA pairs and
|
||||
metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If length of sources doesn't match length of texts.
|
||||
"""
|
||||
if sources is None:
|
||||
sources = ["user_input"] * len(texts)
|
||||
elif len(sources) != len(texts):
|
||||
raise ValueError("Length of sources must match length of texts")
|
||||
|
||||
raw_data = [
|
||||
{
|
||||
'text': text,
|
||||
'source': source,
|
||||
}
|
||||
for text, source in zip(texts, sources)
|
||||
]
|
||||
|
||||
# Construct examples
|
||||
constructor = ExampleConstructor(self.config, self.multi_hop_agent)
|
||||
examples = constructor.construct_examples(raw_data)
|
||||
|
||||
# Manage data
|
||||
curator = DataCurator(self.config, self.rng)
|
||||
final_dataset = curator.curate_dataset(examples)
|
||||
|
||||
return final_dataset
|
||||
|
||||
|
||||
class ExampleConstructor:
|
||||
r"""Constructs training examples from raw text data.
|
||||
|
||||
This class handles the construction of training examples by preprocessing
|
||||
text, extracting information pairs, and generating question-answer pairs.
|
||||
|
||||
Attributes:
|
||||
config (ProcessorConfig): Configuration for example construction.
|
||||
multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for QA
|
||||
generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ProcessorConfig,
|
||||
multi_hop_agent: Optional[MultiHopGeneratorAgent] = None,
|
||||
):
|
||||
r"""Initialize the ExampleConstructor.
|
||||
|
||||
Args:
|
||||
config (ProcessorConfig): Configuration for example construction.
|
||||
multi_hop_agent (Optional[MultiHopGeneratorAgent], optional):
|
||||
Agent for generating multi-hop QA pairs. (default: :obj:`None`)
|
||||
"""
|
||||
self.config = config
|
||||
self.multi_hop_agent = multi_hop_agent
|
||||
|
||||
def construct_examples(
|
||||
self, raw_data: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Construct training examples from raw data.
|
||||
|
||||
Args:
|
||||
raw_data (List[Dict[str, Any]]): List of raw data dictionaries
|
||||
containing text and metadata.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of constructed examples with QA pairs
|
||||
and metadata.
|
||||
"""
|
||||
logger.info("Starting to construct training examples...")
|
||||
examples = []
|
||||
|
||||
for data in tqdm(raw_data, desc="Constructing examples"):
|
||||
# 1. Text preprocessing
|
||||
processed_text = self._preprocess_text(data.get('text', ''))
|
||||
if not processed_text:
|
||||
continue
|
||||
|
||||
# 2. Generate key information pairs
|
||||
info_pairs = self._extract_info_pairs(processed_text)
|
||||
|
||||
# 3. Construct question-answer pairs
|
||||
qa_pairs = self._generate_qa_pairs(info_pairs)
|
||||
|
||||
# 4. Add metadata
|
||||
example = {
|
||||
'text': processed_text,
|
||||
'qa_pairs': qa_pairs,
|
||||
'metadata': {
|
||||
'source': data.get('source', 'unknown'),
|
||||
'timestamp': data.get('timestamp', ''),
|
||||
'complexity': self._calculate_complexity(qa_pairs),
|
||||
},
|
||||
}
|
||||
|
||||
examples.append(example)
|
||||
|
||||
logger.info(f"Successfully constructed {len(examples)} examples")
|
||||
return examples
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
r"""Preprocess input text for example construction.
|
||||
|
||||
Args:
|
||||
text (str): Input text to preprocess.
|
||||
|
||||
Returns:
|
||||
str: Preprocessed text, or empty string if text fails quality
|
||||
checks.
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return ''
|
||||
|
||||
# 1. Basic cleaning
|
||||
text = text.strip()
|
||||
|
||||
# 2. Length check
|
||||
if (
|
||||
len(text) < self.config.min_length
|
||||
or len(text) > self.config.max_length
|
||||
):
|
||||
return ''
|
||||
|
||||
# 3. Quality check
|
||||
if not self._check_text_quality(text):
|
||||
return ''
|
||||
|
||||
return text
|
||||
|
||||
def _check_text_quality(self, text: str) -> bool:
|
||||
r"""Check the quality of input text.
|
||||
|
||||
Args:
|
||||
text (str): Text to check quality for.
|
||||
|
||||
Returns:
|
||||
bool: True if text passes quality checks, False otherwise.
|
||||
"""
|
||||
# 1. Basic quality check
|
||||
if text.count('.') < 2: # Must have at least 2 sentences
|
||||
return False
|
||||
|
||||
# 2. Special character ratio check
|
||||
special_char_ratio = len(
|
||||
[c for c in text if not c.isalnum() and not c.isspace()]
|
||||
) / len(text)
|
||||
if special_char_ratio > 0.3: # No more than 30% special characters
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _extract_info_pairs(self, text: str) -> List[Dict[str, Sequence[str]]]:
|
||||
r"""Extract information pairs and relationships from text.
|
||||
|
||||
Args:
|
||||
text (str): Input text to extract information from.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Sequence[str]]]: List of dictionaries containing
|
||||
premise, intermediate, conclusion, and related contexts.
|
||||
"""
|
||||
# Split into sentences
|
||||
sentences = [s.strip() for s in text.split('.') if s.strip()]
|
||||
info_pairs = []
|
||||
|
||||
# Extract combinations of multiple related sentences
|
||||
for i in range(len(sentences) - 2):
|
||||
if len(sentences[i]) > 10 and len(sentences[i + 1]) > 10:
|
||||
info_pairs.append(
|
||||
{
|
||||
'premise': sentences[i],
|
||||
'intermediate': sentences[i + 1],
|
||||
'conclusion': sentences[i + 2]
|
||||
if i + 2 < len(sentences)
|
||||
else '',
|
||||
'related_contexts': [
|
||||
s
|
||||
for j, s in enumerate(sentences)
|
||||
if j != i and j != i + 1 and len(s) > 10
|
||||
][:2],
|
||||
# Limit to 2 additional related contexts
|
||||
}
|
||||
)
|
||||
|
||||
return info_pairs
|
||||
|
||||
def _generate_qa_pairs(
|
||||
self, info_pairs: List[Dict[str, Sequence[str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""Generate multi-hop question-answer pairs from information pairs.
|
||||
|
||||
Args:
|
||||
info_pairs (List[Dict[str, Sequence[str]]]): List of information
|
||||
pairs extracted from text.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: List of generated QA pairs.
|
||||
"""
|
||||
qa_pairs = []
|
||||
|
||||
for pair in info_pairs:
|
||||
# 1. Generate multi-hop question-answer pair using AI
|
||||
if self.multi_hop_agent:
|
||||
# Construct full context
|
||||
context = (
|
||||
f"{pair['premise']}. {pair['intermediate']}."
|
||||
f" {pair['conclusion']}"
|
||||
)
|
||||
response = self.multi_hop_agent.generate_multi_hop_qa(context)
|
||||
if response:
|
||||
qa_pairs.append(response.value.dict())
|
||||
continue
|
||||
|
||||
return qa_pairs
|
||||
|
||||
def _calculate_complexity(self, qa_pairs: List[Dict[str, Any]]) -> float:
|
||||
r"""Calculate the complexity score for a set of QA pairs.
|
||||
|
||||
Args:
|
||||
qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate
|
||||
complexity for.
|
||||
|
||||
Returns:
|
||||
float: Complexity score between 0.0 and 1.0.
|
||||
"""
|
||||
if not qa_pairs:
|
||||
return 0.0
|
||||
|
||||
# Calculate complexity based on multiple factors
|
||||
complexities = []
|
||||
for qa in qa_pairs:
|
||||
# 1. Number of reasoning steps
|
||||
reasoning_steps_count = len(qa.get('reasoning_steps', []))
|
||||
|
||||
# 2. Number of supporting facts
|
||||
supporting_facts_count = len(qa.get('supporting_facts', []))
|
||||
|
||||
# 3. Question length
|
||||
question_length = len(qa.get('question', '').split())
|
||||
|
||||
# 4. Answer length
|
||||
answer_length = len(qa.get('answer', '').split())
|
||||
|
||||
# Calculate complexity of a single QA pair
|
||||
qa_complexity = (
|
||||
min(reasoning_steps_count / 3, 1.0)
|
||||
* 0.4 # Weight for reasoning steps
|
||||
+ min(supporting_facts_count / 3, 1.0)
|
||||
* 0.3 # Weight for supporting facts
|
||||
+ min(question_length / 20, 1.0)
|
||||
* 0.15 # Weight for question length
|
||||
+ min(answer_length / 50, 1.0) * 0.15
|
||||
# Weight for answer length
|
||||
)
|
||||
|
||||
complexities.append(qa_complexity)
|
||||
|
||||
return sum(complexities) / len(complexities)
|
||||
|
||||
|
||||
class DataCurator:
|
||||
r"""Manages and curates datasets of multi-hop question-answer pairs.
|
||||
|
||||
This class handles dataset management tasks including quality filtering,
|
||||
complexity filtering, deduplication, and dataset sampling.
|
||||
|
||||
Attributes:
|
||||
config (ProcessorConfig): Configuration for data curation parameters.
|
||||
rng (random.Random): Random number generator for reproducible sampling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ProcessorConfig, rng: random.Random):
|
||||
r"""Initialize the DataCurator.
|
||||
|
||||
Args:
|
||||
config (ProcessorConfig): Configuration for data curation.
|
||||
rng (random.Random): Random number generator for reproducibility.
|
||||
"""
|
||||
self.config = config
|
||||
self.rng = rng
|
||||
|
||||
def curate_dataset(
|
||||
self, examples: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Manage and curate a dataset through multiple filtering stages.
|
||||
|
||||
Args:
|
||||
examples (List[Dict[str, Any]]): List of examples to curate.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Curated dataset meeting quality criteria.
|
||||
"""
|
||||
logger.info("Starting dataset management...")
|
||||
|
||||
# 1. Quality filtering
|
||||
quality_filtered = self._quality_filter(examples)
|
||||
logger.info(
|
||||
f"Remaining examples after quality filtering:"
|
||||
f" {len(quality_filtered)}"
|
||||
)
|
||||
|
||||
# 2. Complexity filtering
|
||||
complexity_filtered = self._complexity_filter(quality_filtered)
|
||||
logger.info(
|
||||
f"Remaining examples after complexity filtering:"
|
||||
f" {len(complexity_filtered)}"
|
||||
)
|
||||
|
||||
# 3. Deduplication
|
||||
deduplicated = self._remove_duplicates(complexity_filtered)
|
||||
logger.info(
|
||||
f"Remaining examples after deduplication: {len(deduplicated)}"
|
||||
)
|
||||
|
||||
# 4. Sample to target size
|
||||
final_dataset = self._sample_dataset(deduplicated)
|
||||
logger.info(f"Final dataset size: {len(final_dataset)}")
|
||||
|
||||
return final_dataset
|
||||
|
||||
def _quality_filter(
|
||||
self, examples: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Filter examples based on quality criteria.
|
||||
|
||||
Args:
|
||||
examples (List[Dict[str, Any]]): List of examples to filter.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Examples that pass quality checks.
|
||||
"""
|
||||
filtered = []
|
||||
|
||||
for example in examples:
|
||||
# 1. Check QA pair quality
|
||||
qa_quality = self._check_qa_quality(example.get('qa_pairs', []))
|
||||
|
||||
# 2. Check text quality
|
||||
text_quality = (
|
||||
len(example.get('text', '').split()) >= 20
|
||||
) # At least 20 words
|
||||
|
||||
if qa_quality and text_quality:
|
||||
filtered.append(example)
|
||||
|
||||
return filtered
|
||||
|
||||
def _check_qa_quality(self, qa_pairs: List[Dict[str, str]]) -> bool:
|
||||
r"""Check the quality of question-answer pairs.
|
||||
|
||||
Args:
|
||||
qa_pairs (List[Dict[str, str]]): List of QA pairs to check.
|
||||
|
||||
Returns:
|
||||
bool: True if QA pairs meet quality criteria, False otherwise.
|
||||
"""
|
||||
if not qa_pairs:
|
||||
return False
|
||||
|
||||
for qa in qa_pairs:
|
||||
# 1. Length check
|
||||
if (
|
||||
len(qa.get('question', '')) < 10
|
||||
or len(qa.get('answer', '')) < 5
|
||||
):
|
||||
return False
|
||||
|
||||
# 2. QA pair duplication check
|
||||
if qa.get('question', '') == qa.get('answer', ''):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _complexity_filter(
|
||||
self, examples: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filter examples based on complexity threshold.
|
||||
|
||||
Removes examples with complexity scores below the configured threshold.
|
||||
|
||||
Args:
|
||||
examples (List[Dict[str, Any]]): List of examples to filter.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Examples meeting complexity threshold.
|
||||
"""
|
||||
return [
|
||||
example
|
||||
for example in examples
|
||||
if example.get('metadata', {}).get('complexity', 0)
|
||||
>= self.config.complexity_threshold
|
||||
]
|
||||
|
||||
def _remove_duplicates(
|
||||
self, examples: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Remove duplicate examples from the dataset.
|
||||
|
||||
Args:
|
||||
examples (List[Dict[str, Any]]): List of examples to deduplicate.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Deduplicated examples.
|
||||
"""
|
||||
seen = set()
|
||||
unique_examples = []
|
||||
|
||||
for example in examples:
|
||||
# Use text and QA pair combination as unique identifier
|
||||
text = example.get('text', '')
|
||||
qa_str = str(example.get('qa_pairs', []))
|
||||
|
||||
identifier = hash(text + qa_str)
|
||||
|
||||
if identifier not in seen:
|
||||
seen.add(identifier)
|
||||
unique_examples.append(example)
|
||||
|
||||
return unique_examples
|
||||
|
||||
def _sample_dataset(
|
||||
self, examples: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Sample examples to match target dataset size.
|
||||
|
||||
Args:
|
||||
examples (List[Dict[str, Any]]): List of examples to sample from.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Sampled dataset of target size or smaller.
|
||||
"""
|
||||
if len(examples) <= self.config.dataset_size:
|
||||
return examples
|
||||
|
||||
return self.rng.sample(examples, self.config.dataset_size)
|
||||
93
camel/datagen/source2synth/models.py
Normal file
93
camel/datagen/source2synth/models.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, ClassVar, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReasoningStep(BaseModel):
|
||||
r"""A single step in a multi-hop reasoning process.
|
||||
|
||||
Attributes:
|
||||
step (str): The textual description of the reasoning step.
|
||||
"""
|
||||
|
||||
step: str = Field(
|
||||
..., description="A single step in the reasoning process."
|
||||
)
|
||||
|
||||
|
||||
class MultiHopQA(BaseModel):
|
||||
r"""A multi-hop question-answer pair with reasoning steps and supporting
|
||||
facts.
|
||||
|
||||
Attributes:
|
||||
question (str): The question requiring multi-hop reasoning.
|
||||
reasoning_steps (List[ReasoningStep]): List of reasoning steps to
|
||||
answer.
|
||||
answer (str): The final answer to the question.
|
||||
supporting_facts (List[str]): List of facts supporting the reasoning.
|
||||
type (str): The type of question-answer pair.
|
||||
"""
|
||||
|
||||
question: str = Field(
|
||||
..., description="The question that requires multi-hop reasoning."
|
||||
)
|
||||
reasoning_steps: List[ReasoningStep] = Field(
|
||||
...,
|
||||
description="The steps involved in reasoning to answer the question.",
|
||||
)
|
||||
answer: str = Field(
|
||||
..., description="The answer to the multi-hop question."
|
||||
)
|
||||
supporting_facts: List[str] = Field(
|
||||
..., description="Facts that support the reasoning and answer."
|
||||
)
|
||||
type: str = Field(description="The type of question-answer pair.")
|
||||
|
||||
class Config:
|
||||
json_schema_extra: ClassVar[Dict[str, Any]] = {
|
||||
"example": {
|
||||
"question": "What is the capital of France?",
|
||||
"reasoning_steps": [
|
||||
{"step": "Identify the country France."},
|
||||
{"step": "Find the capital city of France."},
|
||||
],
|
||||
"answer": "Paris",
|
||||
"supporting_facts": [
|
||||
"France is a country in Europe.",
|
||||
"Paris is the capital city of France.",
|
||||
],
|
||||
"type": "multi_hop_qa",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ContextPrompt(BaseModel):
|
||||
r"""A context prompt for generating multi-hop question-answer pairs.
|
||||
|
||||
Attributes:
|
||||
main_context (str): The primary context for generating QA pairs.
|
||||
related_contexts (Optional[List[str]]): Additional related contexts.
|
||||
"""
|
||||
|
||||
main_context: str = Field(
|
||||
...,
|
||||
description="The main context for generating"
|
||||
" the question-answer pair.",
|
||||
)
|
||||
related_contexts: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Additional contexts related to the main context.",
|
||||
)
|
||||
74
camel/datagen/source2synth/user_data_processor_config.py
Normal file
74
camel/datagen/source2synth/user_data_processor_config.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import random
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent
|
||||
|
||||
|
||||
class ProcessorConfig(BaseModel):
|
||||
r"""Data processing configuration class"""
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ProcessorConfig("
|
||||
f"seed={self.seed}, min_length={self.min_length}, "
|
||||
f"max_length={self.max_length}, "
|
||||
f"complexity_threshold={self.complexity_threshold}, "
|
||||
f"dataset_size={self.dataset_size}, "
|
||||
f"use_ai_model={self.use_ai_model}"
|
||||
f")"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
frozen=False,
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
seed: int = Field( # Generate a random seed for reproducibility
|
||||
default_factory=lambda: random.randint(0, 1000),
|
||||
description="Random seed for reproducibility",
|
||||
)
|
||||
|
||||
min_length: int = Field(
|
||||
default=50, description="Minimum text length", ge=0
|
||||
)
|
||||
|
||||
max_length: int = Field(
|
||||
default=512, description="Maximum text length", gt=0
|
||||
)
|
||||
|
||||
complexity_threshold: float = Field(
|
||||
default=0.5,
|
||||
description="Complexity threshold for processing",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
dataset_size: int = Field(
|
||||
default=1000, description="Target size of the dataset", gt=0
|
||||
)
|
||||
|
||||
use_ai_model: bool = Field(
|
||||
default=True, description="Whether to use AI model in processing"
|
||||
)
|
||||
|
||||
hop_generating_agent: MultiHopGeneratorAgent = Field(
|
||||
default_factory=lambda: MultiHopGeneratorAgent(),
|
||||
description="Agent for generating multi-hop text",
|
||||
)
|
||||
23
camel/datahubs/__init__.py
Normal file
23
camel/datahubs/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
from .base import BaseDatasetManager
|
||||
from .huggingface import HuggingFaceDatasetManager
|
||||
from .models import Record
|
||||
|
||||
__all__ = [
|
||||
"BaseDatasetManager",
|
||||
"Record",
|
||||
"HuggingFaceDatasetManager",
|
||||
]
|
||||
136
camel/datahubs/base.py
Normal file
136
camel/datahubs/base.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from camel.datahubs.models import Record
|
||||
|
||||
|
||||
class BaseDatasetManager(ABC):
|
||||
r"""Abstract base class for dataset managers."""
|
||||
|
||||
@abstractmethod
|
||||
def create_dataset(self, name: str, **kwargs: Any) -> str:
|
||||
r"""Creates a new dataset.
|
||||
|
||||
Args:
|
||||
name (str): The name of the dataset.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The URL of the created dataset.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_datasets(
|
||||
self, username: str, limit: int = 100, **kwargs: Any
|
||||
) -> List[str]:
|
||||
r"""Lists all datasets for the current user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user whose datasets to list.
|
||||
limit (int): The maximum number of datasets to list.
|
||||
(default::obj:`100`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of dataset ids.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
||||
r"""Deletes a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset to delete.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Adds records to a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to add to the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Updates records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to update in the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> List[Record]:
|
||||
r"""Lists records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
# New method for record deletion
|
||||
@abstractmethod
|
||||
def delete_record(
|
||||
self,
|
||||
dataset_name: str,
|
||||
record_id: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Deletes a record from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
record_id (str): The ID of the record to delete.
|
||||
filepath (str): The path to the file containing the records.
|
||||
(default::obj:`"records/records.json"`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
pass
|
||||
444
camel/datahubs/huggingface.py
Normal file
444
camel/datahubs/huggingface.py
Normal file
@@ -0,0 +1,444 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from camel.datahubs.base import BaseDatasetManager
|
||||
from camel.datahubs.models import Record
|
||||
from camel.logger import get_logger
|
||||
from camel.types import HuggingFaceRepoType
|
||||
from camel.utils import api_keys_required, dependencies_required
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingFaceDatasetManager(BaseDatasetManager):
|
||||
r"""A dataset manager for Hugging Face datasets. This class provides
|
||||
methods to create, add, update, delete, and list records in a dataset on
|
||||
the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
token (str): The Hugging Face API token. If not provided, the token
|
||||
will be read from the environment variable `HF_TOKEN`.
|
||||
"""
|
||||
|
||||
@api_keys_required(
|
||||
[
|
||||
("token", "HF_TOKEN"),
|
||||
]
|
||||
)
|
||||
@dependencies_required('huggingface_hub')
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
self._api_key = token or os.getenv("HF_TOKEN")
|
||||
self.api = HfApi(token=self._api_key)
|
||||
|
||||
def create_dataset_card(
|
||||
self,
|
||||
dataset_name: str,
|
||||
description: str,
|
||||
license: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
authors: Optional[List[str]] = None,
|
||||
size_category: Optional[List[str]] = None,
|
||||
language: Optional[List[str]] = None,
|
||||
task_categories: Optional[List[str]] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> None:
|
||||
r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
|
||||
format.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
description (str): A description of the dataset.
|
||||
license (str): The license of the dataset. (default: :obj:`None`)
|
||||
version (str): The version of the dataset. (default: :obj:`None`)
|
||||
tags (list): A list of tags for the dataset.(default: :obj:`None`)
|
||||
authors (list): A list of authors of the dataset. (default:
|
||||
:obj:`None`)
|
||||
size_category (list): A size category for the dataset. (default:
|
||||
:obj:`None`)
|
||||
language (list): A list of languages the dataset is in. (default:
|
||||
:obj:`None`)
|
||||
task_categories (list): A list of task categories. (default:
|
||||
:obj:`None`)
|
||||
content (str): Custom markdown content that the user wants to add
|
||||
to the dataset card. (default: :obj:`None`)
|
||||
"""
|
||||
import yaml
|
||||
|
||||
metadata = {
|
||||
"license": license,
|
||||
"authors": authors,
|
||||
"task_categories": task_categories,
|
||||
"language": language,
|
||||
"tags": tags,
|
||||
"pretty_name": dataset_name,
|
||||
"size_categories": size_category,
|
||||
"version": version,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
# Remove keys with None values
|
||||
metadata = {k: v for k, v in metadata.items() if v}
|
||||
|
||||
card_content = (
|
||||
"---\n"
|
||||
+ yaml.dump(metadata, default_flow_style=False, allow_unicode=True)
|
||||
+ "\n---"
|
||||
)
|
||||
|
||||
if content:
|
||||
card_content += f"\n\n# Additional Information\n{content}\n"
|
||||
|
||||
self._upload_file(
|
||||
file_content=card_content,
|
||||
dataset_name=dataset_name,
|
||||
filepath="README.md",
|
||||
file_type="md",
|
||||
)
|
||||
|
||||
def create_dataset(
|
||||
self, name: str, private: bool = False, **kwargs: Any
|
||||
) -> str:
|
||||
r"""Creates a new dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
name (str): The name of the dataset.
|
||||
private (bool): Whether the dataset should be private. defaults to
|
||||
False.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The URL of the created dataset.
|
||||
"""
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
try:
|
||||
self.api.repo_info(
|
||||
repo_id=name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
self.api.create_repo(
|
||||
repo_id=name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
private=private,
|
||||
)
|
||||
|
||||
return f"https://huggingface.co/datasets/{name}"
|
||||
|
||||
def list_datasets(
|
||||
self, username: str, limit: int = 100, **kwargs: Any
|
||||
) -> List[str]:
|
||||
r"""Lists all datasets for the current user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user whose datasets to list.
|
||||
limit (int): The maximum number of datasets to list.
|
||||
(default: :obj:`100`)
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of dataset ids.
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
dataset.id
|
||||
for dataset in self.api.list_datasets(
|
||||
author=username, limit=limit, **kwargs
|
||||
)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing datasets: {e}")
|
||||
return []
|
||||
|
||||
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
|
||||
r"""Deletes a dataset from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset to delete.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
try:
|
||||
self.api.delete_repo(
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"Dataset '{dataset_name}' deleted successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting dataset '{dataset_name}': {e}")
|
||||
raise
|
||||
|
||||
def add_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Adds records to a dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to add to the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset already has a records file.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if existing_records:
|
||||
raise ValueError(
|
||||
f"Dataset '{filepath}' already exists. "
|
||||
f"Use `update_records` to modify."
|
||||
)
|
||||
|
||||
self._upload_records(
|
||||
records=records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def update_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
records: List[Record],
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Updates records in a dataset on the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
records (List[Record]): A list of records to update in the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset does not have an existing file to update
|
||||
records in.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if not existing_records:
|
||||
logger.warning(
|
||||
f"Dataset '{dataset_name}' does not have existing "
|
||||
"records. Adding new records."
|
||||
)
|
||||
self._upload_records(
|
||||
records=records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
old_dict = {record.id: record for record in existing_records}
|
||||
new_dict = {record.id: record for record in records}
|
||||
merged_dict = old_dict.copy()
|
||||
merged_dict.update(new_dict)
|
||||
|
||||
self._upload_records(
|
||||
records=list(merged_dict.values()),
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete_record(
|
||||
self,
|
||||
dataset_name: str,
|
||||
record_id: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Deletes a record from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
record_id (str): The ID of the record to delete.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset does not have an existing file to delete
|
||||
records from.
|
||||
"""
|
||||
existing_records = self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
if not existing_records:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' does not have an existing file to "
|
||||
f"delete records from."
|
||||
)
|
||||
|
||||
filtered_records = [
|
||||
record for record in existing_records if record.id != record_id
|
||||
]
|
||||
|
||||
self._upload_records(
|
||||
records=filtered_records,
|
||||
dataset_name=dataset_name,
|
||||
filepath=filepath,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def list_records(
|
||||
self,
|
||||
dataset_name: str,
|
||||
filepath: str = "records/records.json",
|
||||
**kwargs: Any,
|
||||
) -> List[Record]:
|
||||
r"""Lists all records in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
filepath (str): The path to the file containing the records.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Record]: A list of records in the dataset.
|
||||
"""
|
||||
return self._download_records(
|
||||
dataset_name=dataset_name, filepath=filepath, **kwargs
|
||||
)
|
||||
|
||||
def _download_records(
|
||||
self, dataset_name: str, filepath: str, **kwargs: Any
|
||||
) -> List[Record]:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
|
||||
try:
|
||||
downloaded_file_path = hf_hub_download(
|
||||
repo_id=dataset_name,
|
||||
filename=filepath,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
token=self._api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
with open(downloaded_file_path, "r") as f:
|
||||
records_data = json.load(f)
|
||||
|
||||
return [Record(**record) for record in records_data]
|
||||
except EntryNotFoundError:
|
||||
logger.info(f"No records found for dataset '{dataset_name}'.")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading or processing records: {e}")
|
||||
raise e
|
||||
|
||||
def _upload_records(
|
||||
self,
|
||||
records: List[Record],
|
||||
dataset_name: str,
|
||||
filepath: str,
|
||||
**kwargs: Any,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, mode="w", newline="", encoding="utf-8"
|
||||
) as f:
|
||||
json.dump(
|
||||
[
|
||||
record.model_dump(exclude_defaults=True)
|
||||
for record in records
|
||||
],
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=temp_file_path,
|
||||
path_in_repo=filepath,
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading records file: {e}")
|
||||
raise
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
def _upload_file(
|
||||
self,
|
||||
file_content: str,
|
||||
dataset_name: str,
|
||||
filepath: str,
|
||||
file_type: str = "json",
|
||||
**kwargs: Any,
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=f".{file_type}"
|
||||
) as f:
|
||||
if file_type == "json":
|
||||
if isinstance(file_content, str):
|
||||
try:
|
||||
json_content = json.loads(file_content)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"Invalid JSON string provided for file_content."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
json.dumps(file_content, ensure_ascii=False)
|
||||
json_content = file_content
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"file_content is not JSON serializable."
|
||||
)
|
||||
|
||||
json.dump(json_content, f, ensure_ascii=False)
|
||||
elif file_type == "md" or file_type == "txt":
|
||||
f.write(file_content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {file_type}")
|
||||
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
self.api.upload_file(
|
||||
path_or_fileobj=temp_file_path,
|
||||
path_in_repo=filepath,
|
||||
repo_id=dataset_name,
|
||||
repo_type=HuggingFaceRepoType.DATASET.value,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"File uploaded successfully: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file: {e}")
|
||||
raise
|
||||
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
24
camel/datahubs/models.py
Normal file
24
camel/datahubs/models.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
content: Optional[Dict[str, Any]] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
26
camel/datasets/__init__.py
Normal file
26
camel/datasets/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .base_generator import BaseGenerator
|
||||
from .few_shot_generator import FewShotGenerator
|
||||
from .models import DataPoint
|
||||
from .self_instruct_generator import SelfInstructGenerator
|
||||
from .static_dataset import StaticDataset
|
||||
|
||||
__all__ = [
|
||||
"BaseGenerator",
|
||||
"DataPoint",
|
||||
"FewShotGenerator",
|
||||
"StaticDataset",
|
||||
"SelfInstructGenerator",
|
||||
]
|
||||
292
camel/datasets/base_generator.py
Normal file
292
camel/datasets/base_generator.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from camel.logger import get_logger
|
||||
|
||||
from .models import DataPoint
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseGenerator(abc.ABC, IterableDataset):
|
||||
r"""Abstract base class for data generators.
|
||||
|
||||
This class defines the interface for generating synthetic datapoints.
|
||||
Concrete implementations should provide specific generation strategies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seed: int = 42,
|
||||
buffer: int = 20,
|
||||
cache: Union[str, Path, None] = None,
|
||||
data_path: Union[str, Path, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""Initialize the base generator.
|
||||
|
||||
Args:
|
||||
seed (int): Random seed for reproducibility. (default: :obj:`42`)
|
||||
buffer (int): Amount of DataPoints to be generated when the
|
||||
iterator runs out of DataPoints in data. (default: :obj:`20`)
|
||||
cache (Union[str, Path, None]): Optional path to save generated
|
||||
datapoints during iteration. If None is provided, datapoints
|
||||
will be discarded every 100 generations.
|
||||
data_path (Union[str, Path, None]): Optional path to a JSONL file
|
||||
to initialize the dataset from.
|
||||
**kwargs: Additional generator parameters.
|
||||
"""
|
||||
self._rng = random.Random(seed)
|
||||
self.cache = Path(cache) if cache else None
|
||||
self._buffer = buffer
|
||||
self._data: List[DataPoint] = []
|
||||
self._batch_to_save: List[DataPoint] = []
|
||||
|
||||
if data_path:
|
||||
file_path = Path(data_path)
|
||||
raw_data = self._init_from_jsonl(file_path)
|
||||
try:
|
||||
data_points = [DataPoint(**item) for item in raw_data]
|
||||
self._data.extend(data_points)
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f"Failed to create DataPoint from JSONL data: {e}"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def generate_new(self, n: int, **kwargs) -> None:
|
||||
r"""Generate n new datapoints and append them to self._data.
|
||||
|
||||
Subclass implementations must generate the specified number of
|
||||
datapoints and append them directly to the `self._data` list.
|
||||
This method should not return the datapoints; the iterator
|
||||
relies on `self._data` being populated.
|
||||
|
||||
Args:
|
||||
n (int): Number of datapoints to generate and append.
|
||||
**kwargs: Additional generation parameters.
|
||||
|
||||
Returns:
|
||||
None: This method should not return anything.
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def generate_new(self, n: int, **kwargs) -> None:
|
||||
new_points = [DataPoint(...) for _ in range(n)]
|
||||
self._data.extend(new_points)
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
def __aiter__(self):
|
||||
r"""Async iterator that yields datapoints dynamically.
|
||||
|
||||
If a `data_path` was provided during initialization, those datapoints
|
||||
are yielded first. When self._data is empty, 20 new datapoints
|
||||
are generated. Every 100 yields, the batch is appended to the
|
||||
JSONL file or discarded if `cache` is None.
|
||||
|
||||
Yields:
|
||||
DataPoint: A single datapoint.
|
||||
"""
|
||||
|
||||
async def generator():
|
||||
while True:
|
||||
if not self._data:
|
||||
await self.generate_new(self._buffer)
|
||||
datapoint = self._data.pop(0)
|
||||
yield datapoint
|
||||
self._batch_to_save.append(datapoint)
|
||||
if len(self._batch_to_save) == 100:
|
||||
if self.cache:
|
||||
with self.cache.open("a", encoding="utf-8") as f:
|
||||
for dp in self._batch_to_save:
|
||||
json.dump(dp.to_dict(), f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
self._batch_to_save = []
|
||||
|
||||
return generator()
|
||||
|
||||
def __iter__(self):
|
||||
r"""Synchronous iterator for PyTorch IterableDataset compatibility.
|
||||
|
||||
If a `data_path` was provided during initialization, those datapoints
|
||||
are yielded first. When self._data is empty, 20 new datapoints
|
||||
are generated. Every 100 yields, the batch is appended to the
|
||||
JSONL file or discarded if `cache` is None.
|
||||
|
||||
Yields:
|
||||
DataPoint: A single datapoint.
|
||||
"""
|
||||
try:
|
||||
if asyncio.get_event_loop().is_running():
|
||||
raise RuntimeError(
|
||||
"Cannot use synchronous iteration (__iter__) in an async "
|
||||
"context; use 'async for' with __aiter__ instead"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" not in str(e):
|
||||
raise
|
||||
|
||||
while True:
|
||||
if not self._data:
|
||||
asyncio.run(self.generate_new(self._buffer))
|
||||
datapoint = self._data.pop(0)
|
||||
yield datapoint
|
||||
self._batch_to_save.append(datapoint)
|
||||
if len(self._batch_to_save) == 100:
|
||||
if self.cache:
|
||||
with self.cache.open("a", encoding="utf-8") as f:
|
||||
for dp in self._batch_to_save:
|
||||
json.dump(dp.to_dict(), f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
self._batch_to_save = []
|
||||
|
||||
def sample(self) -> DataPoint:
|
||||
r"""Returns the next datapoint from the current dataset
|
||||
synchronously.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called in an async context.
|
||||
|
||||
Returns:
|
||||
DataPoint: The next DataPoint.
|
||||
|
||||
Note:
|
||||
This method is intended for synchronous contexts.
|
||||
Use 'async_sample' in asynchronous contexts to
|
||||
avoid blocking or runtime errors.
|
||||
"""
|
||||
try:
|
||||
if asyncio.get_event_loop().is_running():
|
||||
raise RuntimeError(
|
||||
"Cannot use synchronous sampling (sample) "
|
||||
"in an async context; use async_sample instead"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" not in str(e):
|
||||
raise
|
||||
|
||||
return next(iter(self))
|
||||
|
||||
async def async_sample(self) -> DataPoint:
|
||||
r"""Returns the next datapoint from the current dataset asynchronously.
|
||||
|
||||
Returns:
|
||||
DataPoint: The next datapoint.
|
||||
|
||||
Note:
|
||||
This method is intended for asynchronous contexts. Use 'sample'
|
||||
in synchronous contexts.
|
||||
"""
|
||||
|
||||
async_iter = self.__aiter__()
|
||||
return await async_iter.__anext__()
|
||||
|
||||
def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
|
||||
r"""Saves the generated datapoints to a JSONL (JSON Lines) file.
|
||||
|
||||
Each datapoint is stored as a separate JSON object on a new line.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): Path to save the JSONL file.
|
||||
|
||||
Raises:
|
||||
ValueError: If no datapoints have been generated.
|
||||
IOError: If there is an issue writing to the file.
|
||||
|
||||
Notes:
|
||||
- Uses `self._data`, which contains the generated datapoints.
|
||||
- Appends to the file if it already exists.
|
||||
- Ensures compatibility with large datasets by using JSONL format.
|
||||
"""
|
||||
if not self._data:
|
||||
raise ValueError("Dataset is empty. No data to save.")
|
||||
|
||||
file_path = Path(file_path)
|
||||
|
||||
try:
|
||||
with file_path.open("a", encoding="utf-8") as f:
|
||||
for datapoint in self._data:
|
||||
json.dump(datapoint.to_dict(), f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
logger.info(f"Dataset saved successfully to {file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"Error writing to file {file_path}: {e}")
|
||||
raise
|
||||
|
||||
def flush(self, file_path: Union[str, Path]) -> None:
|
||||
r"""Flush the current data to a JSONL file and clear the data.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): Path to save the JSONL file.
|
||||
|
||||
Notes:
|
||||
- Uses `save_to_jsonl` to save `self._data`.
|
||||
"""
|
||||
|
||||
self.save_to_jsonl(file_path)
|
||||
self._data = []
|
||||
logger.info(f"Data flushed to {file_path} and cleared from the memory")
|
||||
|
||||
def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
|
||||
r"""Load and parse a dataset from a JSONL file.
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to the JSONL file.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of datapoint dictionaries.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified JSONL file does not exist.
|
||||
ValueError: If a line contains invalid JSON or is not a dictionary.
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"JSONL file not found: {file_path}")
|
||||
|
||||
raw_data = []
|
||||
logger.debug(f"Loading JSONL from {file_path}")
|
||||
with file_path.open('r', encoding='utf-8') as f:
|
||||
for line_number, line in enumerate(f, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue # Skip blank lines
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Invalid JSON on line {line_number} "
|
||||
f"in file {file_path}: {e}"
|
||||
)
|
||||
if not isinstance(record, dict):
|
||||
raise ValueError(
|
||||
f"Expected a dictionary at line {line_number}, "
|
||||
f"got {type(record).__name__}"
|
||||
)
|
||||
raw_data.append(record)
|
||||
logger.info(
|
||||
f"Successfully loaded {len(raw_data)} items from {file_path}"
|
||||
)
|
||||
return raw_data
|
||||
282
camel/datasets/few_shot_generator.py
Normal file
282
camel/datasets/few_shot_generator.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.models.base_model import BaseModelBackend
|
||||
from camel.verifiers import BaseVerifier
|
||||
|
||||
from .base_generator import BaseGenerator
|
||||
from .models import DataPoint
|
||||
from .static_dataset import StaticDataset
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """**You are an advanced data generation assistant.**
|
||||
Your goal is to generate high-quality synthetic data points based on
|
||||
provided examples. Your output must be well-structured,
|
||||
logically sound, and formatted correctly.
|
||||
|
||||
**Instructions:**
|
||||
1. **Follow the Structure**
|
||||
Each data point must include:
|
||||
- **Question**: A clear, well-formed query.
|
||||
- **Rationale**: A step-by-step, executable reasoning process ending
|
||||
with `print(final_answer)`.
|
||||
- **Final Answer**: The correct, concise result.
|
||||
|
||||
2. **Ensure Logical Consistency**
|
||||
- The `rationale` must be code that runs correctly.
|
||||
- The `final_answer` should match the printed output.
|
||||
|
||||
3. **Output Format (Strict)**
|
||||
```
|
||||
Question: [Generated question]
|
||||
Rationale: [Code that solves the question, ending in a print statement,
|
||||
outputting the answer.]
|
||||
Final Answer: [The Final Answer]
|
||||
|
||||
**Now, generate a new data point based on the given examples.**
|
||||
"""
|
||||
|
||||
|
||||
class FewShotGenerator(BaseGenerator):
|
||||
r"""A generator for creating synthetic datapoints using few-shot learning.
|
||||
|
||||
This class leverages a seed dataset, an agent, and a verifier to generate
|
||||
new synthetic datapoints on demand through few-shot prompting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seed_dataset: StaticDataset,
|
||||
verifier: BaseVerifier,
|
||||
model: BaseModelBackend,
|
||||
seed: int = 42,
|
||||
**kwargs,
|
||||
):
|
||||
r"""Initialize the few-shot generator.
|
||||
|
||||
Args:
|
||||
seed_dataset (StaticDataset): Validated static dataset to
|
||||
use for examples.
|
||||
verifier (BaseVerifier): Verifier to validate generated content.
|
||||
model (BaseModelBackend): The underlying LLM that the generating
|
||||
agent will be initiated with.
|
||||
seed (int): Random seed for reproducibility. (default: :obj:`42`)
|
||||
**kwargs: Additional generator parameters.
|
||||
"""
|
||||
super().__init__(seed=seed, **kwargs)
|
||||
self.seed_dataset = seed_dataset
|
||||
try:
|
||||
self._validate_seed_dataset()
|
||||
except Exception:
|
||||
raise RuntimeError("Seed Data does not follow Datapoint format")
|
||||
self.verifier = verifier
|
||||
self.agent = ChatAgent(system_message=SYSTEM_PROMPT, model=model)
|
||||
|
||||
# TODO: Validate that seed dataset contains rationale
|
||||
def _validate_seed_dataset(self) -> None:
|
||||
pass
|
||||
|
||||
def _construct_prompt(self, examples: List[DataPoint]) -> str:
|
||||
r"""Construct a prompt for generating new datapoints
|
||||
using a fixed sample of examples from the seed dataset.
|
||||
|
||||
Args:
|
||||
examples (List[DataPoint]): Examples to include in the prompt.
|
||||
|
||||
Returns:
|
||||
str: Formatted prompt with examples.
|
||||
"""
|
||||
prompt = (
|
||||
"Generate a new datapoint similar to the following examples:\n\n"
|
||||
)
|
||||
for i, example in enumerate(examples, 1):
|
||||
prompt += f"Example {i}:\n"
|
||||
prompt += f"Question: {example.question}\n"
|
||||
if example.rationale is not None:
|
||||
prompt += f"Rationale: {example.rationale}\n"
|
||||
else:
|
||||
prompt += "Rationale: None\n"
|
||||
prompt += f"Final Answer: {example.final_answer}\n\n"
|
||||
prompt += "New datapoint:"
|
||||
return prompt
|
||||
|
||||
async def generate_new(
|
||||
self,
|
||||
n: int,
|
||||
max_retries: int = 10,
|
||||
num_examples: int = 3,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
r"""Generates and validates `n` new datapoints through
|
||||
few-shot prompting, with a retry limit.
|
||||
|
||||
Steps:
|
||||
1. Samples examples from the seed dataset.
|
||||
2. Constructs a prompt using the selected examples.
|
||||
3. Uses an agent to generate a new datapoint,
|
||||
consisting of a question and code to solve the question.
|
||||
4. Executes code using a verifier to get pseudo ground truth.
|
||||
5. Stores valid datapoints in memory.
|
||||
|
||||
Args:
|
||||
n (int): Number of valid datapoints to generate.
|
||||
max_retries (int): Maximum number of retries before stopping.
|
||||
(default: :obj:`10`)
|
||||
num_examples (int): Number of examples to sample from the
|
||||
seed dataset for few shot prompting.
|
||||
(default: :obj:`3`)
|
||||
**kwargs: Additional generation parameters.
|
||||
|
||||
Returns:
|
||||
List[DataPoint]: A list of newly generated valid datapoints.
|
||||
|
||||
Raises:
|
||||
TypeError: If the agent's output is not a dictionary (or does not
|
||||
match the expected format).
|
||||
KeyError: If required keys are missing from the response.
|
||||
AttributeError: If the verifier response lacks attributes.
|
||||
ValidationError: If a datapoint fails schema validation.
|
||||
RuntimeError: If retries are exhausted before `n` valid datapoints
|
||||
are generated.
|
||||
|
||||
Notes:
|
||||
- Retries on validation failures until `n` valid datapoints exist
|
||||
or `max_retries` is reached, whichever comes first.
|
||||
- If retries are exhausted before reaching `n`, a `RuntimeError`
|
||||
is raised.
|
||||
- Metadata includes a timestamp for tracking datapoint creation.
|
||||
"""
|
||||
valid_data_points: List[DataPoint] = []
|
||||
retries = 0
|
||||
|
||||
while len(valid_data_points) < n and retries < max_retries:
|
||||
try:
|
||||
examples = [
|
||||
self.seed_dataset.sample() for _ in range(num_examples)
|
||||
]
|
||||
prompt = self._construct_prompt(examples)
|
||||
|
||||
# Create a simplified version of DataPoint that omits metadata
|
||||
# because agent.step's response_format parameter doesn't
|
||||
# support type Dict[str, Any]
|
||||
class DataPointSimplified(BaseModel):
|
||||
question: str = Field(
|
||||
description="The primary question or issue to "
|
||||
"be addressed."
|
||||
)
|
||||
final_answer: str = Field(description="The final answer.")
|
||||
rationale: str = Field(
|
||||
description="Logical reasoning or explanation "
|
||||
"behind the answer."
|
||||
)
|
||||
|
||||
try:
|
||||
agent_output = (
|
||||
self.agent.step(
|
||||
prompt, response_format=DataPointSimplified
|
||||
)
|
||||
.msgs[0]
|
||||
.parsed
|
||||
)
|
||||
|
||||
assert isinstance(agent_output, DataPointSimplified)
|
||||
|
||||
self.agent.reset()
|
||||
|
||||
except (TypeError, KeyError) as e:
|
||||
logger.warning(
|
||||
f"Agent output issue: {e}, retrying... "
|
||||
f"({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
rationale = agent_output.rationale
|
||||
|
||||
if not isinstance(rationale, str):
|
||||
raise TypeError(f"Rationale {rationale} is not a string.")
|
||||
|
||||
try:
|
||||
verifier_response = await asyncio.wait_for(
|
||||
self.verifier.verify(
|
||||
solution=rationale,
|
||||
reference_answer=None,
|
||||
),
|
||||
timeout=180,
|
||||
)
|
||||
if not verifier_response or not verifier_response.result:
|
||||
raise ValueError(
|
||||
"Verifier unsuccessful, response: "
|
||||
f"{verifier_response}"
|
||||
)
|
||||
except (ValueError, AttributeError, asyncio.TimeoutError) as e:
|
||||
error_msg = (
|
||||
"Verifier timeout"
|
||||
if isinstance(e, asyncio.TimeoutError)
|
||||
else f"Verifier issue: {e}"
|
||||
)
|
||||
logger.warning(
|
||||
f"{error_msg}, retrying... "
|
||||
f"({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
new_datapoint = DataPoint(
|
||||
question=agent_output.question,
|
||||
rationale=rationale,
|
||||
final_answer=verifier_response.result,
|
||||
metadata={
|
||||
"synthetic": str(True),
|
||||
"created": datetime.now().isoformat(),
|
||||
"generator": "few_shot",
|
||||
"shots": [e.to_dict() for e in examples],
|
||||
},
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Datapoint validation failed: {e}, "
|
||||
f"retrying... ({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
valid_data_points.append(new_datapoint)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected error: {e}, retrying..."
|
||||
f" ({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
|
||||
if len(valid_data_points) < n:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate {n} valid datapoints "
|
||||
f"after {max_retries} retries."
|
||||
)
|
||||
|
||||
# Thread-safe way to extend the data list
|
||||
async with asyncio.Lock():
|
||||
self._data.extend(valid_data_points)
|
||||
61
camel/datasets/models.py
Normal file
61
camel/datasets/models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DataPoint(BaseModel):
|
||||
r"""A single data point in the dataset.
|
||||
|
||||
Attributes:
|
||||
question (str): The primary question or issue to be addressed.
|
||||
final_answer (str): The final answer.
|
||||
rationale (Optional[str]): Logical reasoning or explanation behind the
|
||||
answer. (default: :obj:`None`)
|
||||
metadata (Optional[Dict[str, Any]]): Additional metadata about the data
|
||||
point. (default: :obj:`None`)
|
||||
"""
|
||||
|
||||
question: str = Field(
|
||||
..., description="The primary question or issue to be addressed."
|
||||
)
|
||||
final_answer: str = Field(..., description="The final answer.")
|
||||
rationale: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Logical reasoning or explanation behind the answer.",
|
||||
)
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Additional metadata about the data point."
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
r"""Convert DataPoint to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the DataPoint.
|
||||
"""
|
||||
return self.dict()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
|
||||
r"""Create a DataPoint from a dictionary.
|
||||
|
||||
Args:
|
||||
data (Dict[str, Any]): Dictionary containing DataPoint fields.
|
||||
|
||||
Returns:
|
||||
DataPoint: New DataPoint instance.
|
||||
"""
|
||||
return cls(**data)
|
||||
415
camel/datasets/self_instruct_generator.py
Normal file
415
camel/datasets/self_instruct_generator.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Iterable, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from camel.agents import ChatAgent
|
||||
from camel.logger import get_logger
|
||||
from camel.models import ModelFactory
|
||||
from camel.types import ModelPlatformType, ModelType
|
||||
from camel.verifiers import BaseVerifier
|
||||
|
||||
from .base_generator import BaseGenerator
|
||||
from .models import DataPoint
|
||||
from .static_dataset import StaticDataset
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_INSTRUCTION_SYSTEM_PROMPT = """
|
||||
You are a high-capacity instruction generation assistant.
|
||||
|
||||
Your task is to generate a **new, creative, and challenging question** based on
|
||||
several examples.
|
||||
These examples may cover different domains or styles, but your goal is to:
|
||||
- **Understand their specific patterns** in structure, and complexity;
|
||||
- **Combine and synthesize** ideas from multiple examples, rather than copying
|
||||
or lightly editing any single one;
|
||||
- **Intelligently integrate** multiple reasoning steps, constraints, or
|
||||
concepts into a single, coherent question;
|
||||
- Ensure the new question is **non-trivial** and requires deep thinking or
|
||||
multi-step reasoning.
|
||||
|
||||
**Guidelines:**
|
||||
- Use the examples as inspiration for format, depth, and tone.
|
||||
- Your new question should be self-contained, logically sound, and answerable.
|
||||
- Do not repeat exact phrasings or create shallow combinations; instead,
|
||||
produce something meaningfully new.
|
||||
- Avoid open-ended or subjective questions that depend on personal opinions or
|
||||
discussion.
|
||||
- The generated question must have a **clear, objective, and verifiable
|
||||
answer**.
|
||||
- Aim for increased depth or novelty through subtle combination or
|
||||
transformation.
|
||||
- Keep the final output to a **single unified question** with one clear answer,
|
||||
not a multi-part task.
|
||||
|
||||
**Output Format (strict):**
|
||||
```
|
||||
Question: [Generated question]
|
||||
```
|
||||
"""
|
||||
|
||||
DEFAULT_RATIONALE_SYSTEM_PROMPT = """You are an advanced Python code assistant.
|
||||
|
||||
Your task is to **solve the given question by writing Python code only**,
|
||||
without any explanation or natural language output.
|
||||
The code must compute the answer **programmatically**, not by hardcoding or
|
||||
guessing the result.
|
||||
|
||||
**Rules:**
|
||||
- Use Python code to perform the actual computation.
|
||||
- Use {package_list} to solve the problem. Do not import any other libraries.
|
||||
- **Do not hardcode the final answer** (e.g., avoid writing `print(1/2)` unless
|
||||
that value is computed).
|
||||
- The result must be obtained through valid computation logic in code.
|
||||
- Do not include explanations. Output code only.
|
||||
- The entire code must be wrapped in triple backticks:
|
||||
```
|
||||
[Your Python code here]
|
||||
```
|
||||
|
||||
Now, solve the following question using Python. Only output the code:
|
||||
"""
|
||||
|
||||
|
||||
class SelfInstructGenerator(BaseGenerator):
|
||||
r"""A generator for creating synthetic datapoints using self-instruct.
|
||||
|
||||
It utilizes both a human-provided dataset (seed_dataset) and generated
|
||||
machine instructions (machine_instructions) to produce new, synthetic
|
||||
datapoints that include a question, a computed rationale (code), and a
|
||||
final answer (from a verifier).
|
||||
"""
|
||||
|
||||
class QuestionSchema(BaseModel):
|
||||
r"""Schema for the generated question.
|
||||
|
||||
Attributes:
|
||||
question (str): The question generated by the model.
|
||||
"""
|
||||
|
||||
question: str = Field(description="The question generated")
|
||||
|
||||
class RationaleSchema(BaseModel):
|
||||
r"""Schema for the generated rationale code.
|
||||
|
||||
Attributes:
|
||||
code (str): The generated code without any formatting.
|
||||
"""
|
||||
|
||||
code: str = Field(
|
||||
description="The generated code without any formatting"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seed_dataset: StaticDataset,
|
||||
verifier: BaseVerifier,
|
||||
instruction_agent: Optional[ChatAgent] = None,
|
||||
rationale_agent: Optional[ChatAgent] = None,
|
||||
seed: int = 42,
|
||||
**kwargs,
|
||||
):
|
||||
r"""Initialize the self-instruct generator.
|
||||
|
||||
Args:
|
||||
seed_dataset (StaticDataset): Dataset containing seed instructions.
|
||||
verifier (BaseVerifier): Verifier instance to validate generated
|
||||
solutions.
|
||||
instruction_agent (Optional[ChatAgent]): Agent for generating
|
||||
instructions. If not provided, a default agent will be created.
|
||||
rationale_agent (Optional[ChatAgent]): Agent for generating
|
||||
rationales. If not provided, a default agent will be created.
|
||||
seed (int): Random seed for reproducibility. (default: :obj:`42`)
|
||||
**kwargs: Additional keyword arguments passed to the BaseGenerator.
|
||||
"""
|
||||
super().__init__(seed=seed, **kwargs)
|
||||
self.seed_dataset = seed_dataset
|
||||
self.verifier = verifier
|
||||
# extract packages from verifier
|
||||
self.packages: List[str] = getattr(
|
||||
self.verifier, "required_packages", []
|
||||
)
|
||||
# create default agents if not provided
|
||||
self.instruction_agent = (
|
||||
instruction_agent or self.default_instruction_agent()
|
||||
)
|
||||
self.rationale_agent = (
|
||||
rationale_agent or self.default_rationale_agent()
|
||||
)
|
||||
|
||||
# Extract questions from the seed dataset as human_instructions
|
||||
self.human_instructions: List[str] = [
|
||||
dp.question
|
||||
for dp in list(cast(Iterable[DataPoint], self.seed_dataset))
|
||||
]
|
||||
self.machine_instructions: List[DataPoint] = []
|
||||
# Create an instance-level lock for thread-safe updates to _data
|
||||
self._lock = asyncio.Lock()
|
||||
self._data = [] # Storage for generated DataPoint instances
|
||||
|
||||
def default_instruction_agent(self) -> ChatAgent:
|
||||
r"""Create the default instruction generation agent.
|
||||
|
||||
This agent is configured with a moderate temperature setting to
|
||||
encourage creative and diverse instruction generation behavior.
|
||||
|
||||
Returns:
|
||||
ChatAgent: An agent with the default instruction prompt.
|
||||
"""
|
||||
model = ModelFactory.create(
|
||||
model_platform=ModelPlatformType.DEFAULT,
|
||||
model_type=ModelType.DEFAULT,
|
||||
model_config_dict={"temperature": 0.7},
|
||||
)
|
||||
return ChatAgent(
|
||||
DEFAULT_INSTRUCTION_SYSTEM_PROMPT,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def default_rationale_agent(self) -> ChatAgent:
|
||||
r"""Create the default rationale generation agent.
|
||||
|
||||
This agent is configured with a deterministic (zero temperature)
|
||||
setting to ensure consistent and precise rationale generation based on
|
||||
a given instruction and package list.
|
||||
|
||||
Returns:
|
||||
ChatAgent: An agent with the rationale prompt
|
||||
"""
|
||||
model = ModelFactory.create(
|
||||
model_platform=ModelPlatformType.DEFAULT,
|
||||
model_type=ModelType.DEFAULT,
|
||||
model_config_dict={"temperature": 0.0},
|
||||
)
|
||||
return ChatAgent(
|
||||
DEFAULT_RATIONALE_SYSTEM_PROMPT.format(package_list=self.packages),
|
||||
model=model,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_support_block(dp: DataPoint) -> str:
|
||||
r"""Format a DataPoint into a few-shot example block.
|
||||
|
||||
Args:
|
||||
dp (DataPoint): A data point.
|
||||
|
||||
Returns:
|
||||
str: A formatted string containing the question and its
|
||||
corresponding code block in Markdown-style Python format.
|
||||
"""
|
||||
support_q = dp.question.strip()
|
||||
support_code = dp.rationale.strip() if dp.rationale else ""
|
||||
return (
|
||||
f"Question:\n{support_q}\n\n"
|
||||
"Code:\n"
|
||||
"```python\n"
|
||||
f"{support_code}\n"
|
||||
"```"
|
||||
)
|
||||
|
||||
def generate_new_instruction(
|
||||
self,
|
||||
agent: ChatAgent,
|
||||
support_human_dps: list[DataPoint],
|
||||
support_machine_dps: list[DataPoint],
|
||||
) -> str:
|
||||
r"""Generate a new instruction using self-instruct prompting.
|
||||
|
||||
Args:
|
||||
agent (ChatAgent): The agent to use for generating the instruction.
|
||||
support_human_dps (list[DataPoint]): List of human examples to
|
||||
sample.
|
||||
support_machine_dps (list[DataPoint]): List of machine examples to
|
||||
sample.
|
||||
|
||||
Returns:
|
||||
str: The newly generated question.
|
||||
"""
|
||||
human_sample = [dp.question for dp in list(support_human_dps)]
|
||||
machine_sample = [dp.question for dp in list(support_machine_dps)]
|
||||
|
||||
few_shot_examples = human_sample + machine_sample
|
||||
|
||||
# Build the prompt using the few-shot examples
|
||||
prompt = "Below are some question examples:\n\n"
|
||||
for idx, instr in enumerate(few_shot_examples, start=1):
|
||||
prompt += f"Question {idx}: {instr}\n"
|
||||
prompt += f"Question {len(few_shot_examples) + 1}:\n"
|
||||
prompt += "Now generate a new question based on the given examples.\n"
|
||||
|
||||
question_template = f"Question: {prompt}"
|
||||
response = cast(
|
||||
SelfInstructGenerator.QuestionSchema,
|
||||
agent.step(question_template, response_format=self.QuestionSchema)
|
||||
.msgs[0]
|
||||
.parsed,
|
||||
)
|
||||
return response.question
|
||||
|
||||
def generate_rationale(
|
||||
self,
|
||||
question: str,
|
||||
agent: Optional[ChatAgent] = None,
|
||||
support_human_dps: Optional[list[DataPoint]] = None,
|
||||
) -> str:
|
||||
r"""Generate rationale code (solution) for the given question.
|
||||
|
||||
Args:
|
||||
question (str): The question to be solved.
|
||||
agent (Optional[ChatAgent]): The agent to use for generating the
|
||||
rationale. If None is provided, the default rationale agent
|
||||
will be used. (default: :obj:`None`)
|
||||
support_human_dps (Optional[list[DataPoint]]): List of human
|
||||
examples to sample. (default: :obj:`None`)
|
||||
|
||||
Returns:
|
||||
str: The generated code solution as a string.
|
||||
"""
|
||||
|
||||
# Build few-shot example prompt
|
||||
few_shot_prompt = ""
|
||||
if support_human_dps:
|
||||
few_shot_examples = [
|
||||
self.format_support_block(dp) for dp in support_human_dps
|
||||
]
|
||||
few_shot_prompt += "Below are example questions and solutions:\n\n"
|
||||
few_shot_prompt += "\n\n".join(few_shot_examples)
|
||||
|
||||
few_shot_prompt += f"\n\nWrite code to solve the question:\n{question}"
|
||||
|
||||
response = cast(
|
||||
SelfInstructGenerator.RationaleSchema,
|
||||
(agent or self.default_rationale_agent())
|
||||
.step(few_shot_prompt, response_format=self.RationaleSchema)
|
||||
.msgs[0]
|
||||
.parsed,
|
||||
)
|
||||
return response.code
|
||||
|
||||
async def generate_new(
|
||||
self,
|
||||
n: int,
|
||||
max_retries: int = 10,
|
||||
human_sample_count: int = 3,
|
||||
machine_sample_count: int = 1,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
r"""Generates and validates `n` new datapoints through
|
||||
self-instruct prompting, with a retry limit.
|
||||
|
||||
Args:
|
||||
n (int): The number of valid datapoints to generate.
|
||||
max_retries (int): Maximum number of retries before stopping.
|
||||
(default: :obj:`10`)
|
||||
human_sample_count (int): Number of human examples to sample.
|
||||
(default: :obj:`3`)
|
||||
machine_sample_count (int): Number of machine examples to sample.
|
||||
(default: :obj:`1`)
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Notes:
|
||||
- Retries on validation failures until `n` valid datapoints exist
|
||||
or `max_retries` is reached, whichever comes first.
|
||||
- If retries are exhausted before reaching `n`, a `RuntimeError`
|
||||
is raised.
|
||||
- Metadata includes a timestamp for tracking datapoint creation.
|
||||
"""
|
||||
valid_data_points: list[DataPoint] = []
|
||||
retries = 0
|
||||
|
||||
while len(valid_data_points) < n and retries < max_retries:
|
||||
try:
|
||||
human_dps_list = list(cast(List[DataPoint], self.seed_dataset))
|
||||
support_human_dps = random.sample(
|
||||
human_dps_list,
|
||||
min(human_sample_count, len(human_dps_list)),
|
||||
)
|
||||
|
||||
machine_dps_list = list(self.machine_instructions)
|
||||
support_machine_dps = []
|
||||
if machine_dps_list and machine_sample_count > 0:
|
||||
support_machine_dps = random.sample(
|
||||
machine_dps_list,
|
||||
min(machine_sample_count, len(machine_dps_list)),
|
||||
)
|
||||
question = self.generate_new_instruction(
|
||||
self.instruction_agent,
|
||||
support_human_dps,
|
||||
support_machine_dps,
|
||||
)
|
||||
rationale = self.generate_rationale(
|
||||
question, self.rationale_agent, support_human_dps
|
||||
)
|
||||
if not isinstance(rationale, str):
|
||||
raise TypeError(f"Rationale {rationale} is not a string.")
|
||||
|
||||
try:
|
||||
verifier_response = await self.verifier.verify(
|
||||
solution=rationale,
|
||||
reference_answer=None,
|
||||
)
|
||||
if not verifier_response or not verifier_response.result:
|
||||
raise ValueError(
|
||||
"Verifier unsuccessful, response: "
|
||||
f"{verifier_response}"
|
||||
)
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Verifier issue: {e}, "
|
||||
f"retrying... ({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
try:
|
||||
new_datapoint = DataPoint(
|
||||
question=question,
|
||||
rationale=rationale,
|
||||
final_answer=verifier_response.result,
|
||||
metadata={
|
||||
"synthetic": str(True),
|
||||
"created": datetime.now().isoformat(),
|
||||
"generator": "self_instruct",
|
||||
},
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Datapoint validation failed: {e}, "
|
||||
f"retrying... ({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
valid_data_points.append(new_datapoint)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected error: {e}, retrying..."
|
||||
f" ({retries + 1}/{max_retries})"
|
||||
)
|
||||
retries += 1
|
||||
|
||||
if len(valid_data_points) < n:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate {n} valid datapoints "
|
||||
f"after {max_retries} retries."
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._data.extend(valid_data_points)
|
||||
400
camel/datasets/static_dataset.py
Normal file
400
camel/datasets/static_dataset.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sized,
|
||||
Union,
|
||||
)
|
||||
|
||||
from datasets import Dataset as HFDataset
|
||||
from pydantic import ValidationError
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from camel.logger import get_logger
|
||||
|
||||
from .models import DataPoint
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StaticDataset(Dataset):
|
||||
r"""A static dataset containing a list of datapoints.
|
||||
Ensures that all items adhere to the DataPoint schema.
|
||||
This dataset extends :obj:`Dataset` from PyTorch and should
|
||||
be used when its size is fixed at runtime.
|
||||
|
||||
This class can initialize from Hugging Face Datasets,
|
||||
PyTorch Datasets, JSON file paths, or lists of dictionaries,
|
||||
converting them into a consistent internal format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
|
||||
seed: int = 42,
|
||||
min_samples: int = 1,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""Initialize the static dataset and validate integrity.
|
||||
|
||||
Args:
|
||||
data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
|
||||
Input data, which can be one of the following:
|
||||
- A Hugging Face Dataset (:obj:`HFDataset`).
|
||||
- A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
|
||||
- A :obj:`Path` object representing a JSON or JSONL file.
|
||||
- A list of dictionaries with :obj:`DataPoint`-compatible
|
||||
fields.
|
||||
seed (int): Random seed for reproducibility.
|
||||
(default: :obj:`42`)
|
||||
min_samples (int): Minimum required number of samples.
|
||||
(default: :obj:`1`)
|
||||
strict (bool): Whether to raise an error on invalid
|
||||
datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
|
||||
(default: :obj:`False`)
|
||||
**kwargs: Additional dataset parameters.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input data type is unsupported.
|
||||
ValueError: If the dataset contains fewer than :obj:`min_samples`
|
||||
datapoints or if validation fails.
|
||||
FileNotFoundError: If the specified JSON file path does not exist.
|
||||
json.JSONDecodeError: If the JSON file contains invalid formatting.
|
||||
"""
|
||||
|
||||
# Store all parameters in metadata dict for compatibility
|
||||
self._metadata = {
|
||||
**kwargs,
|
||||
}
|
||||
self._rng = random.Random(seed)
|
||||
self._strict = strict
|
||||
|
||||
self.data: List[DataPoint] = self._init_data(data)
|
||||
self._length = len(self.data)
|
||||
|
||||
if self._length < min_samples:
|
||||
raise ValueError(
|
||||
"The dataset does not contain enough samples. "
|
||||
f"Need {max(0, min_samples)}, got {self._length}"
|
||||
)
|
||||
|
||||
def _init_data(
|
||||
self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
|
||||
) -> List[DataPoint]:
|
||||
r"""Convert input data from various formats into a list of
|
||||
:obj:`DataPoint` instances.
|
||||
|
||||
Args:
|
||||
data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]): Input
|
||||
dataset in one of the supported formats.
|
||||
|
||||
Returns:
|
||||
List[DataPoint]: A list of validated :obj:`DataPoint`
|
||||
instances.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input data type is unsupported.
|
||||
ValueError: If the Path has an unsupported file extension.
|
||||
"""
|
||||
|
||||
if isinstance(data, HFDataset):
|
||||
raw_data = self._init_from_hf_dataset(data)
|
||||
elif isinstance(data, Dataset):
|
||||
raw_data = self._init_from_pytorch_dataset(data)
|
||||
elif isinstance(data, Path):
|
||||
if data.suffix == ".jsonl":
|
||||
raw_data = self._init_from_jsonl_path(data)
|
||||
elif data.suffix == ".json":
|
||||
raw_data = self._init_from_json_path(data)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported file extension: {data.suffix}."
|
||||
" Please enter a .json or .jsonl object."
|
||||
)
|
||||
|
||||
elif isinstance(data, list):
|
||||
raw_data = self._init_from_list(data)
|
||||
else:
|
||||
raise TypeError("Unsupported data type")
|
||||
|
||||
def create_datapoint(
|
||||
item: Dict[str, Any], idx: int
|
||||
) -> Optional[DataPoint]:
|
||||
# Add type checks for required fields to make mypy happy
|
||||
question = item.get('question')
|
||||
if not isinstance(question, str):
|
||||
if self._strict:
|
||||
raise ValueError(
|
||||
f"Sample at index {idx} has invalid 'question': "
|
||||
f"expected str, got {type(question)}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping sample at index {idx}: invalid 'question'"
|
||||
)
|
||||
return None
|
||||
|
||||
rationale = item.get('rationale')
|
||||
|
||||
final_answer = item.get('final_answer')
|
||||
if not isinstance(final_answer, str):
|
||||
if self._strict:
|
||||
raise ValueError(
|
||||
f"Sample at index {idx} has invalid 'final_answer': "
|
||||
f"expected str, got {type(final_answer)}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping sample at index {idx}: "
|
||||
"invalid 'final_answer'"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
return DataPoint(
|
||||
question=question,
|
||||
rationale=rationale,
|
||||
final_answer=final_answer,
|
||||
metadata=item.get('metadata'),
|
||||
)
|
||||
except ValidationError as e:
|
||||
if self._strict:
|
||||
raise ValueError(
|
||||
f"Sample at index {idx} validation error: {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping invalid sample at index {idx} "
|
||||
f"due to validation error: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
unfiltered_data = [
|
||||
create_datapoint(item, i) for i, item in enumerate(raw_data)
|
||||
]
|
||||
return [dp for dp in unfiltered_data if dp is not None]
|
||||
|
||||
def __len__(self) -> int:
|
||||
r"""Return the size of the dataset."""
|
||||
return self._length
|
||||
|
||||
def __getitem__(
|
||||
self, idx: Union[int, slice]
|
||||
) -> Union[DataPoint, List[DataPoint]]:
|
||||
r"""Retrieve a datapoint or a batch of datapoints by index or slice.
|
||||
|
||||
Args:
|
||||
idx (Union[int, slice]): Index or slice of the datapoint(s).
|
||||
|
||||
Returns:
|
||||
List[DataPoint]: A list of `DataPoint` objects.
|
||||
|
||||
Raises:
|
||||
IndexError: If an integer `idx` is out of bounds.
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
if idx < 0 or idx >= self._length:
|
||||
raise IndexError(
|
||||
f"Index {idx} out of bounds for dataset "
|
||||
f"of size {self._length}"
|
||||
)
|
||||
return self.data[idx]
|
||||
|
||||
elif isinstance(idx, slice):
|
||||
return self.data[idx.start : idx.stop : idx.step]
|
||||
|
||||
else:
|
||||
raise TypeError(f"Indexing type {type(idx)} not supported.")
|
||||
|
||||
def sample(self) -> DataPoint:
|
||||
r"""Sample a random datapoint from the dataset.
|
||||
|
||||
Returns:
|
||||
DataPoint: A randomly sampled :obj:`DataPoint`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the dataset is empty and no samples can be drawn.
|
||||
"""
|
||||
|
||||
if self._length == 0:
|
||||
raise RuntimeError("Dataset is empty, cannot sample.")
|
||||
idx = self._rng.randint(0, self._length - 1)
|
||||
sample = self[idx]
|
||||
if not isinstance(sample, DataPoint):
|
||||
raise TypeError(
|
||||
f"Expected DataPoint instance, got {type(sample).__name__}"
|
||||
)
|
||||
return sample
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
r"""Retrieve dataset metadata.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A copy of the dataset metadata dictionary.
|
||||
"""
|
||||
|
||||
return self._metadata.copy()
|
||||
|
||||
def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
|
||||
r"""Convert a Hugging Face dataset into a list of dictionaries.
|
||||
|
||||
Args:
|
||||
data (HFDataset): A Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries representing
|
||||
the dataset, where each dictionary corresponds to a datapoint.
|
||||
"""
|
||||
return [dict(item) for item in data]
|
||||
|
||||
def _init_from_pytorch_dataset(
|
||||
self, data: Dataset
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Convert a PyTorch dataset into a list of dictionaries.
|
||||
|
||||
Args:
|
||||
data (Dataset): A PyTorch dataset.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries representing
|
||||
the dataset.
|
||||
|
||||
Raises:
|
||||
TypeError: If the dataset does not implement :obj:`__len__()`
|
||||
or contains non-dictionary elements.
|
||||
"""
|
||||
if not isinstance(data, Sized):
|
||||
raise TypeError(
|
||||
f"{type(data).__name__} does not implement `__len__()`."
|
||||
)
|
||||
raw_data = []
|
||||
|
||||
for i in range(len(data)):
|
||||
item = data[i]
|
||||
if not isinstance(item, dict):
|
||||
raise TypeError(
|
||||
f"Item at index {i} is not a dict: "
|
||||
f"got {type(item).__name__}"
|
||||
)
|
||||
raw_data.append(dict(item))
|
||||
return raw_data
|
||||
|
||||
def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
|
||||
r"""Load and parse a dataset from a JSON file.
|
||||
|
||||
Args:
|
||||
data (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of datapoint dictionaries.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified JSON file does not exist.
|
||||
ValueError: If the JSON content is not a list of dictionaries.
|
||||
json.JSONDecodeError: If the JSON file has invalid formatting.
|
||||
"""
|
||||
|
||||
if not data.exists():
|
||||
raise FileNotFoundError(f"JSON file not found: {data}")
|
||||
try:
|
||||
logger.debug(f"Loading JSON from {data}")
|
||||
with data.open('r', encoding='utf-8') as f:
|
||||
loaded_data = json.load(f)
|
||||
logger.info(
|
||||
f"Successfully loaded {len(loaded_data)} items from {data}"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in file {data}: {e}")
|
||||
if not isinstance(loaded_data, list):
|
||||
raise ValueError("JSON file must contain a list of dictionaries")
|
||||
for i, item in enumerate(loaded_data):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(
|
||||
f"Expected a dictionary at index {i}, "
|
||||
f"got {type(item).__name__}"
|
||||
)
|
||||
return loaded_data
|
||||
|
||||
def _init_from_jsonl_path(self, data: Path) -> List[Dict[str, Any]]:
|
||||
r"""Load and parse a dataset from a JSONL file.
|
||||
|
||||
Args:
|
||||
data (Path): Path to the JSONL file.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of datapoint dictionaries.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified JSONL file does not exist.
|
||||
ValueError: If a line in the file contains invalid JSON or
|
||||
is not a dictionary.
|
||||
"""
|
||||
if not data.exists():
|
||||
raise FileNotFoundError(f"JSONL file not found: {data}")
|
||||
|
||||
raw_data = []
|
||||
logger.debug(f"Loading JSONL from {data}")
|
||||
with data.open('r', encoding='utf-8') as f:
|
||||
for line_number, line in enumerate(f, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue # Skip blank lines if any exist.
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Invalid JSON on line {line_number} in file "
|
||||
f"{data}: {e}"
|
||||
)
|
||||
raw_data.append(record)
|
||||
logger.info(f"Successfully loaded {len(raw_data)} items from {data}")
|
||||
|
||||
for i, item in enumerate(raw_data):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(
|
||||
f"Expected a dictionary at record {i+1} (line {i+1}), "
|
||||
f"got {type(item).__name__}"
|
||||
)
|
||||
return raw_data
|
||||
|
||||
def _init_from_list(
|
||||
self, data: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
r"""Validate and convert a list of dictionaries into a dataset.
|
||||
|
||||
Args:
|
||||
data (List[Dict[str, Any]]): A list of dictionaries where
|
||||
each dictionary must be a valid :obj:`DataPoint`.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The validated list of dictionaries.
|
||||
|
||||
Raises:
|
||||
ValueError: If any item in the list is not a dictionary.
|
||||
"""
|
||||
for i, item in enumerate(data):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(
|
||||
f"Expected a dictionary at index {i}, "
|
||||
f"got {type(item).__name__}"
|
||||
)
|
||||
return data
|
||||
34
camel/embeddings/__init__.py
Normal file
34
camel/embeddings/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
from .azure_embedding import AzureEmbedding
|
||||
from .base import BaseEmbedding
|
||||
from .jina_embedding import JinaEmbedding
|
||||
from .mistral_embedding import MistralEmbedding
|
||||
from .openai_compatible_embedding import OpenAICompatibleEmbedding
|
||||
from .openai_embedding import OpenAIEmbedding
|
||||
from .sentence_transformers_embeddings import SentenceTransformerEncoder
|
||||
from .together_embedding import TogetherEmbedding
|
||||
from .vlm_embedding import VisionLanguageEmbedding
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbedding",
|
||||
"OpenAIEmbedding",
|
||||
"AzureEmbedding",
|
||||
"SentenceTransformerEncoder",
|
||||
"VisionLanguageEmbedding",
|
||||
"MistralEmbedding",
|
||||
"OpenAICompatibleEmbedding",
|
||||
"JinaEmbedding",
|
||||
"TogetherEmbedding",
|
||||
]
|
||||
119
camel/embeddings/azure_embedding.py
Normal file
119
camel/embeddings/azure_embedding.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from camel.embeddings.base import BaseEmbedding
|
||||
from camel.types import EmbeddingModelType
|
||||
from camel.utils import api_keys_required # Add this import
|
||||
|
||||
|
||||
class AzureEmbedding(BaseEmbedding[str]):
|
||||
r"""Provides text embedding functionalities using Azure's OpenAI models.
|
||||
|
||||
Args:
|
||||
model_type (EmbeddingModelType, optional): The model type to be
|
||||
used for text embeddings.
|
||||
(default: :obj:`TEXT_EMBEDDING_3_SMALL`)
|
||||
url (Optional[str], optional): The url to the Azure OpenAI service.
|
||||
(default: :obj:`None`)
|
||||
api_key (str, optional): The API key for authenticating with the
|
||||
Azure OpenAI service. (default: :obj:`None`)
|
||||
api_version (str, optional): The API version for Azure OpenAI service.
|
||||
(default: :obj:`None`)
|
||||
dimensions (Optional[int], optional): The text embedding output
|
||||
dimensions. (default: :obj:`None`)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an unsupported model type is specified.
|
||||
ValueError: If required API configuration is missing.
|
||||
"""
|
||||
|
||||
@api_keys_required(
|
||||
[
|
||||
("api_key", 'AZURE_OPENAI_API_KEY'),
|
||||
("url", 'AZURE_OPENAI_BASE_URL'),
|
||||
]
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
model_type: EmbeddingModelType = (
|
||||
EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
|
||||
),
|
||||
url: Union[str, None] = None,
|
||||
api_key: Union[str, None] = None,
|
||||
api_version: Union[str, None] = None,
|
||||
dimensions: Union[int, None] = None,
|
||||
) -> None:
|
||||
self.model_type = model_type
|
||||
self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
|
||||
if dimensions is None:
|
||||
self.output_dim = model_type.output_dim
|
||||
else:
|
||||
if not isinstance(dimensions, int):
|
||||
raise ValueError("dimensions must be an integer")
|
||||
self.output_dim = dimensions
|
||||
|
||||
self._api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
self._url = url or os.environ.get("AZURE_OPENAI_BASE_URL")
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_key=self._api_key,
|
||||
api_version=self.api_version,
|
||||
azure_endpoint=str(self._url),
|
||||
)
|
||||
|
||||
def embed_list(
|
||||
self,
|
||||
objs: list[str],
|
||||
**kwargs: Any,
|
||||
) -> list[list[float]]:
|
||||
r"""Embeds a list of texts using the Azure OpenAI model.
|
||||
|
||||
Args:
|
||||
objs (list[str]): The list of texts to embed.
|
||||
**kwargs (Any): Additional keyword arguments to pass to the API.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: The embeddings for the input texts.
|
||||
"""
|
||||
if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
|
||||
response = self.client.embeddings.create(
|
||||
input=objs,
|
||||
model=self.model_type.value,
|
||||
**kwargs,
|
||||
)
|
||||
return [data.embedding for data in response.data]
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
input=objs,
|
||||
model=self.model_type.value,
|
||||
dimensions=self.output_dim,
|
||||
**kwargs,
|
||||
)
|
||||
return [data.embedding for data in response.data]
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
r"""Returns the output dimension of the embeddings.
|
||||
|
||||
Returns:
|
||||
int: The dimensionality of the embedding for the current model.
|
||||
"""
|
||||
return self.output_dim
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user