initial update workforce code for gaia

This commit is contained in:
Yuhang Zhou
2025-05-16 22:28:11 +08:00
commit bc6d952d84
400 changed files with 78717 additions and 0 deletions

41
.env_example Normal file
View File

@@ -0,0 +1,41 @@
# To use these environment variables:
# 1. Populate the .env file with your API keys.
# 2. Include the following code snippet in your Python script:
# from dotenv import load_dotenv
# import os
#
# load_dotenv() # Load environment variables from .env file
#===========================================
# Models API
#===========================================
# OpenAI API (https://platform.openai.com/signup)
OPENAI_API_KEY="Fill your API key here"
# Anthropic API (https://www.anthropic.com/)
ANTHROPIC_API_KEY="Fill your API key here"
# Hugging Face API (https://huggingface.co/join)
HF_TOKEN="Fill your API key here"
# Azure OpenAI API (https://azure.microsoft.com/products/cognitive-services/openai-service/)
AZURE_OPENAI_API_KEY="Fill your API key here"
AZURE_API_VERSION="Fill your API Version here"
AZURE_DEPLOYMENT_NAME="Fill your Deployment Name here"
AZURE_OPENAI_BASE_URL="Fill your Base URL here"
#===========================================
# Tools & Services API
#===========================================
# Google Search API (https://developers.google.com/custom-search/v1/overview)
GOOGLE_API_KEY="Fill your API key here"
SEARCH_ENGINE_ID="Fill your Search Engine ID here"
# Firecrawl API (https://www.firecrawl.dev/)
FIRECRAWL_API_KEY="Fill your API key here"
# Chunkr API (https://chunkr.ai/)
CHUNKR_API_KEY="Fill your API key here"

63
.gitignore vendored Normal file
View File

@@ -0,0 +1,63 @@
# Python
__pycache__/
**/__pycache__/
*/__pycache__/*
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
venv/
env/
ENV/
.env
# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store
# Project specific
data/gaia
tmp
.env
utils/__pycache__/
# Logs
*.log
logs/
log/
# Coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
coverage.xml
*.cover
camel/types/__pycache__/
camel/__pycache__/
camel/utils/__pycache__/
data/*

25
camel/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.logger import disable_logging, enable_logging, set_log_level
# Package version; bump on every release.
__version__ = '0.2.47'

# Names re-exported as the package-level public API.
__all__ = [
    '__version__',
    'camel',
    'disable_logging',
    'enable_logging',
    'set_log_level',
]

46
camel/agents/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .critic_agent import CriticAgent
from .embodied_agent import EmbodiedAgent
from .knowledge_graph_agent import KnowledgeGraphAgent
from .repo_agent import RepoAgent
from .role_assignment_agent import RoleAssignmentAgent
from .search_agent import SearchAgent
from .task_agent import (
TaskCreationAgent,
TaskPlannerAgent,
TaskPrioritizationAgent,
TaskSpecifyAgent,
)
from .tool_agents.base import BaseToolAgent
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
# Public API of camel.agents; keep in sync with the imports above.
__all__ = [
    'BaseAgent',
    'ChatAgent',
    'TaskSpecifyAgent',
    'TaskPlannerAgent',
    'TaskCreationAgent',
    'TaskPrioritizationAgent',
    'CriticAgent',
    'BaseToolAgent',
    'HuggingFaceToolAgent',
    'EmbodiedAgent',
    'RoleAssignmentAgent',
    'SearchAgent',
    'KnowledgeGraphAgent',
    'RepoAgent',
]

41
camel/agents/_types.py Normal file
View File

@@ -0,0 +1,41 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union
from openai import AsyncStream, Stream
from pydantic import BaseModel, ConfigDict
from camel.messages import BaseMessage
from camel.types import ChatCompletion
class ToolCallRequest(BaseModel):
    r"""The request for tool calling.

    Attributes:
        tool_name (str): Name of the tool the model asked to invoke.
        args (Dict[str, Any]): Keyword arguments to pass to the tool.
        tool_call_id (str): Identifier correlating this request with the
            tool result message sent back to the model.
    """

    tool_name: str
    args: Dict[str, Any]
    tool_call_id: str
class ModelResponse(BaseModel):
    r"""The response from the model.

    Attributes:
        response (Union[ChatCompletion, Stream, AsyncStream]): The raw
            backend response, either a full completion or a (sync/async)
            stream of chunks.
        tool_call_requests (Optional[List[ToolCallRequest]]): Tool calls
            requested by the model, if any.
        output_messages (List[BaseMessage]): Messages extracted from the
            response choices.
        finish_reasons (List[str]): One finish reason per choice.
        usage_dict (Dict[str, Any]): Token-usage accounting for the call.
        response_id (str): Backend-assigned identifier of the response.
    """

    # Stream/AsyncStream are not pydantic-native types, so arbitrary
    # types must be allowed for the `response` field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    response: Union[ChatCompletion, Stream, AsyncStream]
    tool_call_requests: Optional[List[ToolCallRequest]]
    output_messages: List[BaseMessage]
    finish_reasons: List[str]
    usage_dict: Dict[str, Any]
    response_id: str

188
camel/agents/_utils.py Normal file
View File

@@ -0,0 +1,188 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import re
import textwrap
from typing import Any, Callable, Dict, List, Optional, Union
from camel.agents._types import ToolCallRequest
from camel.toolkits import FunctionTool
from camel.types import Choice
from camel.types.agents import ToolCallingRecord
logger = logging.getLogger(__name__)
def generate_tool_prompt(tool_schema_list: List[Dict[str, Any]]) -> str:
    r"""Generates a tool prompt based on the provided tool schema list.

    Args:
        tool_schema_list (List[Dict[str, Any]]): OpenAI-style tool schemas,
            each with a ``"function"`` entry holding name/description.

    Returns:
        str: A string representing the tool prompt.
    """
    # One section per tool: a usage line followed by its JSON schema.
    sections = [
        (
            f"Use the function '{schema['function']['name']}' to "
            f"'{schema['function']['description']}':\n"
            f"{json.dumps(schema['function'], indent=4, ensure_ascii=False)}\n"
        )
        for schema in tool_schema_list
    ]
    joined_sections = "\n".join(sections)
    return textwrap.dedent(
        f"""\
You have access to the following functions:

{joined_sections}

If you choose to call a function ONLY reply in the following format with no prefix or suffix:

<function=example_function_name>{{"example_name": "example_value"}}</function>

Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
"""  # noqa: E501
    )
def extract_tool_call(
    content: str,
) -> Optional[Dict[str, Any]]:
    r"""Extract the first tool call from the model output, if present.

    Args:
        content (str): The text content of the model response.

    Returns:
        Optional[Dict[str, Any]]: A dict with keys ``"function"`` (the tool
            name) and ``"arguments"`` (the parsed JSON args) if a
            well-formed call is found, otherwise None.
    """
    # Matches e.g. <function=name>{"arg": "value"}</function>
    function_regex = r"<function=(\w+)>(.*?)</function>"
    match = re.search(function_regex, content)
    if not match:
        return None
    function_name, args_string = match.groups()
    try:
        args = json.loads(args_string)
    except json.JSONDecodeError as error:
        # Malformed JSON payload: log and treat it as "no tool call".
        logger.error(f"Error parsing function arguments: {error}")
        return None
    return {"function": function_name, "arguments": args}
def safe_model_dump(obj) -> Dict[str, Any]:
    r"""Safely dump a Pydantic model to a dictionary.

    Tries the Pydantic v2 ``model_dump`` method first and falls back to
    the v1 ``dict`` method.

    Raises:
        TypeError: If *obj* exposes neither dump method.
    """
    # Probe v2 then v1 dump methods, in order of preference.
    for method_name in ("model_dump", "dict"):
        if hasattr(obj, method_name):
            return getattr(obj, method_name)()
    raise TypeError("The object is not a Pydantic model")
def convert_to_function_tool(
    tool: Union[FunctionTool, Callable],
) -> FunctionTool:
    r"""Convert a tool to a FunctionTool from Callable."""
    # Pass FunctionTool instances through unchanged; wrap bare callables.
    if isinstance(tool, FunctionTool):
        return tool
    return FunctionTool(tool)
def convert_to_schema(
    tool: Union[FunctionTool, Callable, Dict[str, Any]],
) -> Dict[str, Any]:
    r"""Convert a tool to a schema from Callable or FunctionTool."""
    if isinstance(tool, FunctionTool):
        return tool.get_openai_tool_schema()
    if callable(tool):
        return FunctionTool(tool).get_openai_tool_schema()
    # Already a dict schema: pass it through unchanged.
    return tool
def get_info_dict(
    session_id: Optional[str],
    usage: Optional[Dict[str, int]],
    termination_reasons: List[str],
    num_tokens: int,
    tool_calls: List[ToolCallingRecord],
    external_tool_call_requests: Optional[List[ToolCallRequest]] = None,
) -> Dict[str, Any]:
    r"""Returns a dictionary containing information about the chat session.

    Args:
        session_id (str, optional): The ID of the chat session.
        usage (Dict[str, int], optional): Information about the usage of
            the LLM.
        termination_reasons (List[str]): The reasons for the termination
            of the chat session.
        num_tokens (int): The number of tokens used in the chat session.
        tool_calls (List[ToolCallingRecord]): The list of function
            calling records, containing the information of called tools.
        external_tool_call_requests (Optional[List[ToolCallRequest]]): The
            requests for external tool calls.

    Returns:
        Dict[str, Any]: The chat session information.
    """
    session_info: Dict[str, Any] = {}
    session_info["id"] = session_id
    session_info["usage"] = usage
    session_info["termination_reasons"] = termination_reasons
    session_info["num_tokens"] = num_tokens
    session_info["tool_calls"] = tool_calls
    session_info["external_tool_call_requests"] = external_tool_call_requests
    return session_info
def handle_logprobs(choice: Choice) -> Optional[List[Dict[str, Any]]]:
    r"""Flatten a choice's logprobs into a list of plain dictionaries.

    Returns None when the choice carries no logprob information;
    otherwise one dict per token with its logprob and the
    (token, logprob) pairs of the top alternatives.
    """
    logprobs = choice.logprobs
    if logprobs is None or logprobs.content is None:
        return None
    flattened: List[Dict[str, Any]] = []
    for entry in logprobs.content:
        flattened.append(
            {
                "token": entry.token,
                "logprob": entry.logprob,
                "top_logprobs": [
                    (alt.token, alt.logprob) for alt in entry.top_logprobs
                ],
            }
        )
    return flattened

29
camel/agents/base.py Normal file
View File

@@ -0,0 +1,29 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
    r"""An abstract base class for all CAMEL agents.

    Concrete agents must implement :meth:`reset` and :meth:`step`.
    """

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Resets the agent to its initial state.

        Subclasses decide which state (memory, counters, etc.) is cleared.
        """
        pass

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Performs a single step of the agent.

        Subclasses define the step's inputs and the returned result.
        """
        pass

1407
camel/agents/chat_agent.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,202 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import warnings
from typing import Any, Dict, Optional, Sequence
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.memories import AgentMemory
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import get_first_int, print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
    r"""A class for the critic agent that assists in selecting an option.

    Args:
        system_message (BaseMessage): The system message for the critic
            agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        memory (AgentMemory, optional): The agent memory to use.
            (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`6`)
        retry_attempts (int, optional): The number of retry attempts if the
            critic fails to return a valid option. (default: :obj:`2`)
        verbose (bool, optional): Whether to print the critic's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: int = 6,
        retry_attempts: int = 2,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        super().__init__(
            system_message,
            model=model,
            memory=memory,
            message_window_size=message_window_size,
        )
        # Maps option number (as a string, "1"-based) to option content.
        self.options_dict: Dict[str, str] = {}
        self.retry_attempts = retry_attempts
        self.verbose = verbose
        self.logger_color = logger_color

    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
        r"""Flattens the options to the critic.

        Args:
            messages (Sequence[BaseMessage]): A non-empty list of
                `BaseMessage` objects, one per candidate option.

        Returns:
            str: A string containing the flattened options to the critic.
        """
        options = [message.content for message in messages]
        flatten_options = (
            f"> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            flatten_options += f"Option {index + 1}:\n{option}\n\n"
            self.options_dict[str(index + 1)] = option
        # Named `choice_prompt` (not `format`) to avoid shadowing the builtin.
        choice_prompt = (
            f"Please first enter your choice ([1-{len(self.options_dict)}]) "
            "and then your explanation and comparison: "
        )
        return flatten_options + choice_prompt

    def get_option(self, input_message: BaseMessage) -> str:
        r"""Gets the option selected by the critic.

        Args:
            input_message (BaseMessage): A `BaseMessage` object representing
                the input message.

        Returns:
            str: The option selected by the critic. Falls back to a random
                option after `retry_attempts` invalid choices.

        Raises:
            RuntimeError: If the critic step terminated or produced no
                messages.
        """
        # TODO: Add support for editing options by the critic.
        msg_content = input_message.content
        i = 0
        while i < self.retry_attempts:
            critic_response = self.step(input_message)
            if critic_response.msgs is None or len(critic_response.msgs) == 0:
                raise RuntimeError("Got None critic messages.")
            if critic_response.terminated:
                raise RuntimeError("Critic step failed.")
            critic_msg = critic_response.msg
            if self.verbose:
                print_text_animated(
                    self.logger_color + "\n> Critic response: "
                    f"\x1b[3m{critic_msg.content}\x1b[0m\n"
                )
            choice = self.parse_critic(critic_msg)
            if choice in self.options_dict:
                return self.options_dict[choice]
            else:
                # Invalid choice: re-ask with the original options appended.
                input_message = BaseMessage(
                    role_name=input_message.role_name,
                    role_type=input_message.role_type,
                    meta_dict=input_message.meta_dict,
                    content="> Invalid choice. Please choose again.\n"
                    + msg_content,
                )
            i += 1
        warnings.warn(
            "Critic failed to get a valid option. "
            f"After {self.retry_attempts} attempts. "
            "Returning a random option."
        )
        return random.choice(list(self.options_dict.values()))

    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
        r"""Parses the critic's message and extracts the choice.

        Args:
            critic_msg (BaseMessage): A `BaseMessage` object representing the
                critic's response.

        Returns:
            Optional[str]: The critic's choice as a string, or None if the
                message could not be parsed.
        """
        first_int = get_first_int(critic_msg.content)
        # Previously this returned the literal string "None" when no integer
        # was found; return a real None to honor the Optional[str] contract.
        return None if first_int is None else str(first_int)

    def reduce_step(
        self,
        input_messages: Sequence[BaseMessage],
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by flattening options to the
        critic, getting the option, and parsing the choice.

        Args:
            input_messages (Sequence[BaseMessage]): A list of BaseMessage
                objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object includes the
                critic's choice.
        """
        meta_chat_message = BaseMessage(
            role_name=input_messages[0].role_name,
            role_type=input_messages[0].role_type,
            meta_dict=input_messages[0].meta_dict,
            content="",
        )
        flatten_options = self.flatten_options(input_messages)
        if self.verbose:
            print_text_animated(
                self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
            )
        input_msg = meta_chat_message.create_new_instance(flatten_options)
        option = self.get_option(input_msg)
        output_msg = meta_chat_message.create_new_instance(option)
        # TODO: The return `info` can be improved.
        return ChatAgentResponse(
            msgs=[output_msg],
            terminated=False,
            info={},
        )

View File

@@ -0,0 +1,303 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
logger = get_logger(__name__)
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
    r"""An agent responsible for deductive reasoning. Model of deductive
    reasoning:
        - L: A ⊕ C -> q * B
        - A represents the known starting state.
        - B represents the known target state.
        - C represents the conditions required to transition from A to B.
        - Q represents the quality or effectiveness of the transition from
        A to B.
        - L represents the path or process from A to B.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Insight Agent",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def deduce_conditions_and_quality(
        self,
        starting_state: str,
        target_state: str,
        role_descriptions_dict: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Union[List[str], Dict[str, str]]]:
        r"""Derives the conditions and quality from the starting state and the
        target state based on the model of the deductive reasoning and the
        knowledge base. It can optionally consider the roles involved in the
        scenario, which allows tailoring the output more closely to the AI
        agent's environment.

        Args:
            starting_state (str): The initial or starting state from which
                conditions are deduced.
            target_state (str): The target state of the task.
            role_descriptions_dict (Optional[Dict[str, str]], optional): A
                dictionary describing the roles involved in the scenario.
                This is optional and can be used to provide a context for
                the CAMEL's role-playing, enabling the generation of more
                relevant and tailored conditions and quality assessments.
                This could be generated using a `RoleAssignmentAgent()` or
                defined manually by the user. (default: :obj:`None`)

        Returns:
            Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with
                the extracted data from the message. The dictionary contains
                three keys:
                - 'conditions': A dictionary where each key is a condition
                  ID (e.g. "condition 1") and each value is the
                  corresponding condition text.
                - 'labels': A list of label strings extracted from the
                  message.
                - 'evaluate_quality': A string with the quality assessment
                  extracted from the message.

        Raises:
            RuntimeError: If the underlying chat step terminated before a
                deduction could be produced.
        """
        # Clear conversation state so earlier deductions don't leak into
        # this one.
        self.reset()
        deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template
===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.
===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
- Characterization: Before delving into transitions, thoroughly understand
the nature and boundaries of both $A$ and $B$. This includes the type,
properties, constraints, and possible interactions between the two.
- Contrast and Compare: Highlight the similarities and differences between
$A$ and $B$. This comparative analysis will give an insight into what
needs changing and what remains constant.
2. Historical & Empirical Analysis:
- Previous Transitions according to the Knowledge Base of GPT: (if
applicable) Extract conditions and patterns from the historical instances
where a similar transition from a state comparable to $A$ moved towards
$B$.
- Scientific Principles: (if applicable) Consider the underlying
scientific principles governing or related to the states and their
transition. For example, if $A$ and $B$ are physical states, laws of
physics might apply.
3. Logical Deduction of Conditions ($C$):
- Direct Path Analysis: What are the immediate and direct conditions
required to move from $A$ to $B$?
- Intermediate States: Are there states between $A$ and $B$ that must be
traversed or can be used to make the transition smoother or more
efficient? If yes, what is the content?
- Constraints & Limitations: Identify potential barriers or restrictions
in moving from $A$ to $B$. These can be external (e.g., environmental
factors) or internal (properties of $A$ or $B$).
- Resource and Information Analysis: What resources and information are
required for the transition? This could be time, entity, factor, code
language, software platform, unknowns, etc.
- External Influences: Consider socio-economic, political, or
environmental factors (if applicable) that could influence the transition
conditions.
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
no matter how unconventional they might seem. Utilize analogies,
metaphors, or brainstorming techniques to envision possible conditions or
paths from $A$ to $B$.
- The conditions $C$ should be multiple but in one sentence. And each
condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
- Identify and categorize entities of Conditions ($C$) such as the names,
locations, dates, specific technical terms or contextual parameters that
might be associated with events, innovations post-2022.
- The output of the entities/labels will be used as tags or labels for
semantic similarity searches. The entities/labels may be the words, or
phrases, each of them should contain valuable, high information entropy
information, and should be independent.
- Ensure that the identified entities are formatted in a manner suitable
for database indexing and retrieval. Organize the entities into
categories, and combine the category with its instance into a continuous
phrase, without using colons or other separators.
- Format these entities for database indexing: output the category rather
than its instance/content into a continuous phrase. For example, instead
of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
- Efficiency: How efficient is the transition from $A$ to $B$, which
measures the resources used versus the desired outcome?
- Effectiveness: Did the transition achieve the desired outcome or was the
target state achieved as intended?
- Safety & Risks: Assess any risks associated with the transition and the
measures to mitigate them.
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
- Test & Refine: Based on the initially deduced conditions and assessed
quality, iterate the process to refine and optimize the transition. This
might involve tweaking conditions, employing different paths, or changing
resources.
- Feedback Integration: Use feedback to make improvements and increase the
quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
- Scenario Exploration: Continuously imagine various possible scenarios,
both positive and negative, to prepare for unexpected events.
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
process ($L$) if unforeseen challenges arise.
- Feedback Integration: Rapidly integrate feedback from actual
implementations to adjust the model's application, ensuring relevancy and
effectiveness.
===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?
===== STARTING STATE $A$ =====
{starting_state}
===== TARGET STATE $B$ =====
{target_state}
{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
condition <NUM>:
<BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
<BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""
        # Build the optional role-description section only when roles are
        # provided; otherwise the placeholder is replaced with "".
        if role_descriptions_dict is not None:
            role_names = role_descriptions_dict.keys()
            role_with_description_prompt = (
                "===== ROLES WITH DESCRIPTIONS =====\n"
                + "\n".join(
                    f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
                    for role_name in role_names
                )
                + "\n\n"
            )
        else:
            role_with_description_prompt = ""
        deduce_prompt = TextPrompt(deduce_prompt)
        deduce = deduce_prompt.format(
            starting_state=starting_state,
            target_state=target_state,
            role_with_description_prompt=role_with_description_prompt,
        )
        conditions_and_quality_generation_msg = BaseMessage.make_user_message(
            role_name="Deductive Reasoner", content=deduce
        )
        response = self.step(
            input_message=conditions_and_quality_generation_msg
        )
        if response.terminated:
            raise RuntimeError(
                "Deduction failed. Error:\n" + f"{response.info}"
            )
        msg: BaseMessage = response.msg
        logger.info(f"Message content:\n{msg.content}")
        # Extract the conditions from the message.
        # Each "condition <i>: <text>" pair up to the "- Entity" section is
        # captured; <BLANK> angle brackets are stripped from the text.
        conditions_dict = {
            f"condition {i}": cdt.replace("<", "")
            .replace(">", "")
            .strip()
            .strip('\n')
            for i, cdt in re.findall(
                r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
                msg.content,
                re.DOTALL,
            )
        }
        # Extract the labels from the message.
        # NOTE(review): assumes the answer contains exactly one
        # "[...]" label section; an absent section raises IndexError.
        labels = [
            label.strip().strip('\n').strip("\"'")
            for label in re.findall(
                r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
                msg.content,
                re.DOTALL,
            )[0].split(",")
        ]
        # Extract the quality from the message.
        # NOTE(review): next() raises StopIteration if the quality section
        # is missing from the answer.
        quality = next(
            q.strip().strip('\n')
            for q in re.findall(
                r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
                r"\n(.+?)- Iterative",
                msg.content,
                re.DOTALL,
            )
        )
        # Convert them into JSON format.
        conditions_and_quality_json: Dict[
            str, Union[List[str], Dict[str, str]]
        ] = {}
        conditions_and_quality_json["conditions"] = conditions_dict
        conditions_and_quality_json["labels"] = labels
        conditions_and_quality_json["evaluate_quality"] = quality
        return conditions_and_quality_json

View File

@@ -0,0 +1,201 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional
from colorama import Fore
from camel.agents.chat_agent import ChatAgent
from camel.agents.tool_agents.base import BaseToolAgent
from camel.interpreters import (
BaseInterpreter,
InternalPythonInterpreter,
SubprocessInterpreter,
)
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
    r"""Class for managing conversations of CAMEL Embodied Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        tool_agents (List[BaseToolAgent], optional): The tools agents to use in
            the embodied agent. (default: :obj:`None`)
        code_interpreter (BaseInterpreter, optional): The code interpreter to
            execute codes. If `code_interpreter` and `tool_agent` are both
            `None`, default to `SubProcessInterpreter`. If `code_interpreter`
            is `None` and `tool_agents` is not `None`, default to
            `InternalPythonInterpreter`. (default: :obj:`None`)
        verbose (bool, optional): Whether to print the critic's messages.
        logger_color (Any): The color of the logger displayed to the user.
            (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        message_window_size: Optional[int] = None,
        tool_agents: Optional[List[BaseToolAgent]] = None,
        code_interpreter: Optional[BaseInterpreter] = None,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.tool_agents = tool_agents
        self.code_interpreter: BaseInterpreter
        # Choose the interpreter: an explicit one always wins; otherwise use
        # the in-process interpreter when tool agents need to share state,
        # else an isolated subprocess interpreter.
        if code_interpreter is not None:
            self.code_interpreter = code_interpreter
        elif self.tool_agents:
            self.code_interpreter = InternalPythonInterpreter()
        else:
            self.code_interpreter = SubprocessInterpreter()

        if self.tool_agents:
            # Inject the tool-agent action space into the system message.
            system_message = self._set_tool_agents(system_message)
        self.verbose = verbose
        self.logger_color = logger_color
        super().__init__(
            system_message=system_message,
            model=model,
            message_window_size=message_window_size,
        )

    def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
        r"""Formats the system message with the tool-agent action space and
        registers the tool agents with the code interpreter.

        Args:
            system_message (BaseMessage): The system message template whose
                content contains an ``{action_space}`` placeholder.

        Returns:
            BaseMessage: A new system message with the action space filled in.
        """
        action_space_prompt = self._get_tool_agents_prompt()
        result_message = system_message.create_new_instance(
            content=system_message.content.format(
                action_space=action_space_prompt
            )
        )
        if self.tool_agents is not None:
            # Expose each tool agent to executed code under its own name.
            self.code_interpreter.update_action_space(
                {tool.name: tool for tool in self.tool_agents}
            )
        return result_message

    def _get_tool_agents_prompt(self) -> str:
        r"""Returns the action space prompt.

        Returns:
            str: The action space prompt listing each tool agent's name and
                description, or an empty string if there are no tool agents.
        """
        if self.tool_agents is not None:
            return "\n".join(
                [
                    f"*** {tool.name} ***:\n {tool.description}"
                    for tool in self.tool_agents
                ]
            )
        else:
            return ""

    def get_tool_agent_names(self) -> List[str]:
        r"""Returns the names of tool agents.

        Returns:
            List[str]: The names of tool agents.
        """
        if self.tool_agents is not None:
            return [tool.name for tool in self.tool_agents]
        else:
            return []

    # ruff: noqa: E501
    def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
        r"""Performs a step in the conversation.

        Args:
            input_message (BaseMessage): The input message.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.

        Raises:
            RuntimeError: If the model produced no output messages or the
                underlying chat step terminated.
        """
        response = super().step(input_message)

        if response.msgs is None or len(response.msgs) == 0:
            raise RuntimeError("Got None output messages.")
        if response.terminated:
            raise RuntimeError(f"{self.__class__.__name__} step failed.")

        # NOTE: Only single output messages are supported
        explanations, codes = response.msg.extract_text_and_code_prompts()

        if self.verbose:
            for explanation, code in zip(explanations, codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanation}"
                )
                print_text_animated(self.logger_color + f"> Code:\n{code}")

            # There may be one trailing explanation with no paired code block.
            if len(explanations) > len(codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanations[-1]}"
                )

        content = response.msg.content

        # Fix: test for a non-empty list. `extract_text_and_code_prompts`
        # returns a list, so the previous check `codes is not None` was always
        # true and the model's message content was discarded (replaced by an
        # empty "Executed Results" header) even when no code was present.
        if codes:
            try:
                content = "\n> Executed Results:\n"
                for block_idx, code in enumerate(codes):
                    executed_output = self.code_interpreter.run(
                        code, code.code_type
                    )
                    content += (
                        f"Executing code block {block_idx}: {{\n"
                        + executed_output
                        + "}\n"
                    )
            except InterruptedError as e:
                content = (
                    f"\n> Running code fail: {e}\n"
                    "Please regenerate the code."
                )
            # TODO: Handle errors other than InterruptedError as well.

        content = input_message.content + f"\n> Embodied Actions:\n{content}"
        message = BaseMessage(
            input_message.role_name,
            input_message.role_type,
            input_message.meta_dict,
            content,
        )
        return ChatAgentResponse(
            msgs=[message],
            terminated=response.terminated,
            info=response.info,
        )

View File

@@ -0,0 +1,278 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from unstructured.documents.elements import Element
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.storages.graph_storages.graph_element import (
GraphElement,
Node,
Relationship,
)
from camel.types import RoleType
# AgentOps decorator setting
# Use the real `agentops.track_agent` decorator only when an AgentOps API key
# is configured; otherwise fall back to `camel.utils.track_agent` (presumably
# a no-op stand-in -- confirm) so classes below work without the optional
# dependency.
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        # No API key configured: behave as if agentops were not installed.
        raise ImportError
except (ImportError, AttributeError):
    # Fallback decorator provided by camel.utils.
    from camel.utils import track_agent
# Default instruction prompt for the knowledge-graph agent. It tells the model
# to emit Node(...) / Relationship(...) lines that
# `KnowledgeGraphAgent._parse_graph_elements` parses with regexes, so the
# "Expected Output" format below must stay in sync with those regexes.
text_prompt = """
You are tasked with extracting nodes and relationships from given content and
structures them into Node and Relationship objects. Here's the outline of what
you needs to do:
Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.
Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.
Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.
Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.
Do not wrap the output in lists or dictionaries, provide the Node and
Relationship with unique identifiers.
Strictly follow the format provided in the example output, do not add any
additional information.
Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.
Example for you:
Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."
Expected Output:
Nodes:
Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')
Relationships:
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')
===== TASK =====
Please extracts nodes and relationships from given content and structures them
into Node and Relationship objects.
{task}
"""
@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
        prompt: Optional[str] = None,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.
            prompt (str, optional): The custom prompt to be used.
                Defaults to `None`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        self.reset()
        # Keep the element so _parse_graph_elements can attach it as the
        # source of the parsed graph.
        self.element = element

        # Use the provided prompt or fall back to the default text_prompt
        final_prompt = prompt if prompt is not None else text_prompt
        knowledge_graph_prompt = TextPrompt(final_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        response = self.step(input_message=knowledge_graph_generation)
        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions to extract nodes and relationships; these must
        # match the output format requested by `text_prompt`.
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), "
            r"type='(.*?)'(?:, timestamp='(.*?)')?\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes, de-duplicating by id (first valid occurrence wins).
        # Renamed from `id`/`type` to avoid shadowing the builtins.
        for match in re.finditer(node_pattern, input_string):
            node_id, node_type = match.groups()
            properties = {'source': 'agent_created'}
            if node_id not in nodes:
                node = Node(id=node_id, type=node_type, properties=properties)
                if self._validate_node(node):
                    nodes[node_id] = node

        # Extract relationships. `Match.groups()` always yields six entries
        # for this six-group pattern -- the optional timestamp group is simply
        # None when absent -- so the previous `len(groups) == 6` branch was
        # always taken and its `else` arm was dead code.
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type, timestamp = (
                match.groups()
            )
            properties = {'source': 'agent_created'}
            # Only keep relationships whose endpoints were parsed as nodes.
            if subj_id in nodes and obj_id in nodes:
                relationship = Relationship(
                    subj=nodes[subj_id],
                    obj=nodes[obj_id],
                    type=rel_type,
                    timestamp=timestamp,
                    properties=properties,
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )

View File

@@ -0,0 +1,117 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import textwrap
from typing import Any
from pydantic import ConfigDict
from camel.agents.programmed_agent_instruction import (
ProgrammableChatAgent,
ProgrammedAgentInstructionResult,
programmable_capability,
)
from camel.datagen.source2synth.models import (
ContextPrompt,
MultiHopQA,
)
from camel.messages import BaseMessage
class MultiHopGeneratorAgent(ProgrammableChatAgent):
    r"""An agent specialized in generating multi-hop question-answer pairs.

    This agent is designed to create complex questions that require multiple
    steps of reasoning to answer. It analyzes context to identify related
    facts and generates questions that require connecting these facts
    logically.

    Attributes:
        model_config (ConfigDict): Configuration for model behavior.
        system_message (BaseMessage): System message defining agent's role and
            instructions.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the MultiHopGeneratorAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)

        system_text: str = textwrap.dedent(
            """\
            You are an expert at generating
            multi-hop question-answer pairs.
            For each context, you should:
            1. Identify multiple related facts or pieces of information
            2. Create questions that require reasoning across these multiple pieces
            3. Ensure the reasoning chain is clear and logical
            4. Generate questions that require at least 2-3 steps of reasoning
            5. Include the reasoning steps in the answer
            Give your response with this information:
            Question: [Complex question requiring multiple reasoning steps]
            Reasoning Steps:
            1. [First reasoning step]
            2. [Second reasoning step]
            3. [Final reasoning step]
            Answer: [Final answer]
            Supporting Facts: [List of relevant text segments used]
            """  # noqa: E501
        )
        self._system_message = BaseMessage.make_assistant_message(
            role_name='Assistant', content=system_text
        )

    @programmable_capability
    def generate_multi_hop_qa(
        self, context: str
    ) -> ProgrammedAgentInstructionResult[MultiHopQA]:
        r"""Generate a multi-hop question-answer pair from given context.

        Args:
            context (str): The input text context to generate QA from.

        Returns:
            ProgrammedAgentInstructionResult[MultiHopQA]: Result containing the
                generated question, reasoning steps, answer, and supporting
                facts.

        Raises:
            RuntimeError: If the agent fails to generate a response.
        """
        context_prompt = ContextPrompt(
            main_context=context, related_contexts=None
        )

        user_message = BaseMessage.make_user_message(
            content=context_prompt.model_dump_json(), role_name="User"
        )
        response = self.step(
            input_message=user_message, response_format=MultiHopQA
        )

        # Fix: check for an empty response BEFORE indexing into it. The
        # previous code called `response.msgs[0]` first, so an empty message
        # list raised IndexError instead of the documented RuntimeError.
        if not response.msgs:
            raise RuntimeError("No response from agent")

        value = MultiHopQA.model_validate_json(response.msgs[0].content)
        return ProgrammedAgentInstructionResult(
            user_message=user_message,
            agent_message=response.msgs[0],
            value=value,
        )

View File

@@ -0,0 +1,203 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import abc
import threading
from enum import Enum
from functools import wraps
from typing import Any, Callable, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict
from camel.agents import ChatAgent
from camel.messages import BaseMessage
# Generic type parameter for the value carried by
# ProgrammedAgentInstructionResult.
T = TypeVar('T')
class ProgrammableAgentRequirement(Enum):
    r"""State requirements a programmable agent may be asked to repair.

    Each member names a condition that must hold on the agent's conversation
    state before a programmed operation can safely proceed; ``repair_state``
    receives one of these members and restores the corresponding invariant.

    Attributes:
        LAST_MESSAGE_NOT_USER (str): The conversation must not currently end
            with a user message.
    """

    LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"
class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
    r"""Result of a programmable agent instruction execution.

    Contains the messages exchanged during execution and the computed value.
    The value type is specified by the generic type parameter T.

    Attributes:
        user_message (BaseMessage): The message sent by the user.
        agent_message (BaseMessage): The message sent by the agent.
        value (T): The computed result value of type T.
    """

    # Message that triggered the instruction.
    user_message: BaseMessage
    # Agent's reply produced while executing the instruction.
    agent_message: BaseMessage
    # Computed result; concrete type fixed by the generic parameter T.
    value: T

    # BaseMessage is not a pydantic model, so arbitrary field types must be
    # explicitly allowed for validation to accept it.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class AbstractProgrammableAgent(abc.ABC):
    r"""Abstract class for a programmable agent.

    A programmable agent is an agent that can be programmed to perform a
    specific function or task. This class defines the interface for a
    programmable agent.

    These methods should be implemented in order to ensure the agent supports
    the necessary guarantees to enable a programming interface while
    maintaining compatibility in a multi-agent system.

    A programmable agent is responsible for providing and maintaining a
    programming interface for its functionality.
    """

    @abc.abstractmethod
    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        An atomic operation is an operation that is guaranteed to
        be executed without interruption by any other operation.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Agents may have other non-atomic interfaces, such as a user interface,
        or chat between other agents. This method should restore the agent to
        a state where it can perform operations according to the specified
        requirement.

        Args:
            requirement (ProgrammableAgentRequirement): The requirement to
                repair the state for.
        """
        raise NotImplementedError
def programmable_capability(
    func: Callable[..., ProgrammedAgentInstructionResult[T]],
) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
    r"""Decorator marking a method as a programmable agent capability.

    The decorated method is routed through the agent's ``run_atomic`` so the
    call executes without interruption and the agent's state-tracking
    guarantees are preserved.

    Args:
        func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
            to decorate.

    Returns:
        Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
            method that ensures atomic execution.
    """

    @wraps(func)
    def wrapper(
        self, *args: Any, **kwargs: Any
    ) -> ProgrammedAgentInstructionResult[T]:
        # Defer the real call so run_atomic controls exactly when it runs.
        def invoke() -> ProgrammedAgentInstructionResult[T]:
            return func(self, *args, **kwargs)

        return self.run_atomic(invoke)

    return wrapper
class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
    r"""A chat agent that can be programmed to perform specific tasks.

    Provides a default implementation of atomic execution using threading
    locks and basic state tracking for message roles. Implementing classes
    need to provide specific repair logic for their use cases.

    Attributes:
        _operation_lock (threading.Lock): Lock for ensuring atomic operations.
        _last_message_role (Optional[str]): Role of the last message in the
            conversation.
    """

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the ProgrammableChatAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)
        self._last_message_role: Optional[str] = None
        self._operation_lock = threading.Lock()

    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        Uses a non-blocking lock acquire so that a second concurrent
        operation fails fast instead of waiting.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        acquired = self._operation_lock.acquire(blocking=False)
        if not acquired:
            raise RuntimeError("Operation already in progress")
        try:
            result = callback()
            # Remember who spoke last so repair_state() can reason about it.
            self._last_message_role = result.agent_message.role_name
            return result
        finally:
            self._operation_lock.release()

    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Implements basic state repair for message role requirements.

        Args:
            requirement (ProgrammableAgentRequirement): The requirement to
                repair the state for.
        """
        if requirement != ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER:
            return
        if self._last_message_role == "user":
            raise NotImplementedError(
                "Must implement repair for LAST_MESSAGE_NOT_USER"
            )

579
camel/agents/repo_agent.py Normal file
View File

@@ -0,0 +1,579 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import time
from enum import Enum, auto
from string import Template
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
if TYPE_CHECKING:
from github.MainClass import Github
from pydantic import BaseModel
from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend, ModelFactory
from camel.responses import ChatAgentResponse
from camel.retrievers import VectorRetriever
from camel.types import (
ModelPlatformType,
ModelType,
OpenAIBackendRole,
RoleType,
)
from camel.utils import track_agent
from camel.utils.chunker import CodeChunker
logger = get_logger(__name__)
class ProcessingMode(Enum):
    r"""How RepoAgent injects repository context into the conversation."""

    # Explicit values match what `auto()` assigned (1 and 2).
    FULL_CONTEXT = 1  # entire repo content placed in the prompt
    RAG = 2  # retrieve relevant chunks via the vector store
class GitHubFile(BaseModel):
    r"""Model to hold GitHub file information.

    Attributes:
        content (str): The content of the GitHub text.
        file_path (str): The path of the file.
        html_url (str): The actual url of the file.
    """

    # Decoded UTF-8 text content of the file.
    content: str
    # Path in the form "owner/repo/path/to/file" (see load_repository).
    file_path: str
    # Web URL of the file on github.com.
    html_url: str
class RepositoryInfo(BaseModel):
    r"""Model to hold GitHub repository information.

    Attributes:
        repo_name (str): The full name of the repository.
        repo_url (str): The URL of the repository.
        contents (list): A list to hold the repository contents.
    """

    # Full name in "owner/repo" form.
    repo_name: str
    # HTML URL of the repository.
    repo_url: str
    # Loaded text files; pydantic deep-copies mutable defaults per instance,
    # so the shared `[]` default is safe here (unlike a plain class attr).
    contents: List[GitHubFile] = []
@track_agent(name="RepoAgent")
class RepoAgent(ChatAgent):
r"""A specialized agent designed to interact with GitHub repositories for
code generation tasks.
The RepoAgent enhances a base ChatAgent by integrating context from
one or more GitHub repositories. It supports two processing modes:
- FULL_CONTEXT: loads and injects full repository content into the
prompt.
- RAG (Retrieval-Augmented Generation): retrieves relevant
code/documentation chunks using a vector store when context
length exceeds a specified token limit.
Attributes:
vector_retriever (VectorRetriever): Retriever used to
perform semantic search in RAG mode. Required if repo content
exceeds context limit.
system_message (Optional[str]): The system message
for the chat agent. (default: :str:`"You are a code assistant
with repo context."`)
repo_paths (Optional[List[str]]): List of GitHub repository URLs to
load during initialization. (default: :obj:`None`)
model (BaseModelBackend): The model backend to use for generating
responses. (default: :obj:`ModelPlatformType.DEFAULT`
with `ModelType.DEFAULT`)
max_context_tokens (Optional[int]): Maximum number of tokens allowed
before switching to RAG mode. (default: :obj:`2000`)
github_auth_token (Optional[str]): GitHub personal access token
for accessing private or rate-limited repositories. (default:
:obj:`None`)
chunk_size (Optional[int]): Maximum number of characters per code chunk
when indexing files for RAG. (default: :obj:`8192`)
top_k (int): Number of top-matching chunks to retrieve from the vector
store in RAG mode. (default: :obj:`5`)
similarity (Optional[float]): Minimum similarity score required to
include a chunk in the RAG context. (default: :obj:`0.6`)
collection_name (Optional[str]): Name of the vector database
collection to use for storing and retrieving chunks. (default:
:obj:`None`)
**kwargs: Inherited from ChatAgent
Note:
The current implementation of RAG mode requires using Qdrant as the
vector storage backend. The VectorRetriever defaults to QdrantStorage
if no storage is explicitly provided. Other vector storage backends
are not currently supported for the RepoAgent's RAG functionality.
"""
    def __init__(
        self,
        vector_retriever: VectorRetriever,
        system_message: Optional[
            str
        ] = "You are a code assistant with repo context.",
        repo_paths: Optional[List[str]] = None,
        model: Optional[BaseModelBackend] = None,
        max_context_tokens: int = 2000,
        github_auth_token: Optional[str] = None,
        chunk_size: Optional[int] = 8192,
        top_k: Optional[int] = 5,
        similarity: Optional[float] = 0.6,
        collection_name: Optional[str] = None,
        **kwargs,
    ):
        r"""Initialize the RepoAgent (see class docstring for parameters).

        Side effects: when ``repo_paths`` is given, the repositories are
        fetched from GitHub immediately, the full context text is built and
        token-counted, and either RAG mode is enabled or the full text is
        written into agent memory.
        """
        # Fall back to the default model backend when none is supplied.
        if model is None:
            model = ModelFactory.create(
                model_platform=ModelPlatformType.DEFAULT,
                model_type=ModelType.DEFAULT,
            )
        super().__init__(system_message=system_message, model=model, **kwargs)
        self.max_context_tokens = max_context_tokens
        self.vector_retriever = vector_retriever
        self.github_auth_token = github_auth_token
        self.chunk_size = chunk_size
        # Token count of the current full context; updated after each load.
        self.num_tokens = 0
        # Start in full-context mode; check_switch_mode() may flip to RAG.
        self.processing_mode = ProcessingMode.FULL_CONTEXT
        self.top_k = top_k
        self.similarity = similarity
        self.collection_name = collection_name
        # $type / $repo placeholders are filled by construct_full_text().
        self.prompt_template = Template(
            "$type: $repo\n"
            "You are an AI coding assistant. "
            "Your task is to generate code based on provided GitHub "
            "repositories. \n"
            "### Instructions: \n1. **Analyze the Repositories**: "
            "Identify which repositories contain relevant "
            "information for the user's request. Ignore unrelated ones.\n"
            "2. **Extract Context**: Use code, documentation, "
            "dependencies, and tests to understand functionality.\n"
            "3. **Generate Code**: Create clean, efficient, and "
            "well-structured code that aligns with relevant repositories. \n"
            "4. **Justify Output**: Explain which repositories "
            "influenced your solution and why others were ignored."
            "\n If the repositories lack necessary details, "
            "infer best practices and suggest improvements.\n"
            "Now, analyze the repositories and generate the "
            "required code."
        )
        self.full_text = ""
        self.chunker = CodeChunker(chunk_size=chunk_size or 8192)
        self.repos: List[RepositoryInfo] = []
        if repo_paths:
            self.repos = self.load_repositories(repo_paths)
        if len(self.repos) > 0:
            # Build the concatenated context and decide FULL_CONTEXT vs RAG.
            self.construct_full_text()
            self.num_tokens = self.count_tokens()
            if not self.check_switch_mode():
                # Context fits: inject the whole repository text into memory.
                self.update_memory(
                    message=BaseMessage.make_user_message(
                        role_name=RoleType.USER.value,
                        content=self.full_text,
                    ),
                    role=OpenAIBackendRole.SYSTEM,
                )
def parse_url(self, url: str) -> Tuple[str, str]:
r"""Parse the GitHub URL and return the (owner, repo_name) tuple.
Args:
url (str): The URL to be parsed.
Returns:
Tuple[str, str]: The (owner, repo_name) tuple.
"""
try:
url_path = url.replace("https://github.com/", "")
parts = url_path.split("/")
if len(parts) != 2:
raise ValueError("Incorrect GitHub repo URL format.")
else:
return parts[0], parts[1]
except Exception as e:
logger.error(f"Error parsing URL: {e}")
raise Exception(e)
def load_repositories(
self,
repo_urls: List[str],
) -> List[RepositoryInfo]:
r"""Load the content of a GitHub repository.
Args:
repo_urls (str): The list of Repo URLs.
Returns:
List[RepositoryInfo]: A list of objects containing information
about the all repositories, including the contents.
"""
from github.MainClass import Github
github_client = Github(self.github_auth_token)
res = []
for repo_url in repo_urls:
try:
res.append(self.load_repository(repo_url, github_client))
except Exception as e:
logger.error(f"Error loading repository: {e}")
raise Exception(e)
time.sleep(1)
logger.info(f"Successfully loaded {len(res)} repositories.")
return res
    def load_repository(
        self,
        repo_url: str,
        github_client: "Github",
    ) -> RepositoryInfo:
        r"""Load the content of a GitHub repository.

        Walks the repository tree breadth-first, downloads each text file,
        and skips binary or non-UTF-8 files.

        Args:
            repo_url (str): The repo URL to be loaded.
            github_client (Github): The established GitHub client.

        Returns:
            RepositoryInfo: The object containing information
                about the repository, including the contents.
        """
        from github.ContentFile import ContentFile

        try:
            owner, repo_name = self.parse_url(repo_url)
            repo = github_client.get_repo(f"{owner}/{repo_name}")
            # "" fetches the repository's root directory listing.
            contents = repo.get_contents("")
        except Exception as e:
            logger.error(f"Error loading repository: {e}")
            raise Exception(e)
        info = RepositoryInfo(
            repo_name=repo.full_name,
            repo_url=repo.html_url,
            contents=[],
        )
        # Create a list to process repository contents
        content_list: List[ContentFile] = []
        if isinstance(contents, list):
            content_list = contents
        else:
            # Handle single ContentFile case
            content_list = [contents]
        # Queue-style traversal: pop files/dirs from the front, push
        # directory contents onto the back.
        while content_list:
            file = content_list.pop(0)
            if file.type == "file":
                # Skip known binary / non-text extensions outright.
                if any(
                    file.path.endswith(ext)
                    for ext in [
                        ".png",
                        ".jpg",
                        ".pdf",
                        ".zip",
                        ".gitignore",
                        ".mp4",
                        ".avi",
                        ".mov",
                        ".mp3",
                        ".wav",
                        ".tar",
                        ".gz",
                        ".7z",
                        ".rar",
                        ".iso",
                        ".gif",
                        ".docx",
                    ]
                ):
                    logger.info(f"Skipping binary file: {file.path}")
                    continue
                try:
                    file_obj = repo.get_contents(file.path)
                    # Handle file_obj which could be a single ContentFile or a
                    # list
                    if isinstance(file_obj, list):
                        if not file_obj:  # Skip empty lists
                            continue
                        file_obj = file_obj[
                            0
                        ]  # Take the first item if it's a list
                    # Only base64-encoded blobs expose decoded_content;
                    # anything else (e.g. oversized files) is skipped.
                    if getattr(file_obj, "encoding", None) != "base64":
                        logger.warning(
                            f"Skipping file with unsupported "
                            f"encoding: {file.path}"
                        )
                        continue
                    try:
                        content_bytes = file_obj.decoded_content
                        file_content = content_bytes.decode("utf-8")
                    except UnicodeDecodeError:
                        logger.warning(f"Skipping non-UTF-8 file: {file.path}")
                        continue
                    except Exception as e:
                        logger.error(
                            f"Failed to decode file content at "
                            f"{file.path}: {e}"
                        )
                        continue
                    github_file = GitHubFile(
                        content=file_content,
                        file_path=f"{owner}/{repo_name}/{file.path}",
                        html_url=file.html_url,
                    )
                    info.contents.append(github_file)
                except Exception as e:
                    logger.error(f"Error loading file: {e}")
                    raise Exception(e)
                logger.info(f"Successfully loaded file: {file.path}")
            elif file.type == "dir":
                dir_contents = repo.get_contents(file.path)
                # Handle dir_contents which could be a single ContentFile or a
                # list
                if isinstance(dir_contents, list):
                    content_list.extend(dir_contents)
                else:
                    content_list.append(dir_contents)
        return info
def count_tokens(self) -> int:
r"""To count the tokens that's currently in the memory
Returns:
int: The number of tokens
"""
counter = self.model_backend.token_counter
content_token_count = counter.count_tokens_from_messages(
messages=[
BaseMessage.make_user_message(
role_name=RoleType.USER.value,
content=self.full_text,
).to_openai_message(OpenAIBackendRole.USER)
]
)
return content_token_count
def construct_full_text(self):
r"""Construct full context text from repositories by concatenation."""
repo_texts = [
{"content": f.content, "path": f.file_path}
for repo in self.repos
for f in repo.contents
]
self.full_text = self.prompt_template.safe_substitute(
type="Repository",
repo="\n".join(
f"{repo['path']}\n{repo['content']}" for repo in repo_texts
),
)
def add_repositories(self, repo_urls: List[str]):
r"""Add a GitHub repository to the list of repositories.
Args:
repo_urls (str): The Repo URL to be added.
"""
new_repos = self.load_repositories(repo_urls)
self.repos.extend(new_repos)
self.construct_full_text()
self.num_tokens = self.count_tokens()
if self.processing_mode == ProcessingMode.RAG:
for repo in new_repos:
for f in repo.contents:
self.vector_retriever.process(
content=f.content,
should_chunk=True,
extra_info={"file_path": f.file_path},
chunker=self.chunker,
)
else:
self.check_switch_mode()
def check_switch_mode(self) -> bool:
r"""Check if the current context exceeds the context window; if so,
switch to RAG mode.
Returns:
bool: True if the mode was switched, False otherwise.
"""
if self.processing_mode == ProcessingMode.RAG:
return False
if self.num_tokens > self.max_context_tokens:
if not self.vector_retriever:
logger.warning(
f"Token count ({self.num_tokens}) exceeds limit "
f"({self.max_context_tokens}). "
"Either reduce repository size or provide a "
"VectorRetriever."
)
return False
logger.info("Switching to RAG mode and indexing repositories...")
self.processing_mode = ProcessingMode.RAG
for repo in self.repos:
for f in repo.contents:
self.vector_retriever.process(
content=f.content,
should_chunk=True,
extra_info={"file_path": f.file_path},
chunker=self.chunker,
)
self._system_message = None
self.reset()
return True
return False
    def step(
        self, input_message: Union[BaseMessage, str], *args, **kwargs
    ) -> ChatAgentResponse:
        r"""Overrides `ChatAgent.step()` to first retrieve relevant context
        from the vector store before passing the input to the language model.

        In RAG mode the user query is used to retrieve matching chunks; the
        full files containing those chunks are deduplicated, sorted by
        similarity, and appended to the query before delegating to the
        parent ``step``. In full-context mode the message is passed through
        unchanged. On retrieval failure the original message is forwarded
        as-is (best-effort behavior).
        """
        if (
            self.processing_mode == ProcessingMode.RAG
            and self.vector_retriever
        ):
            # Extract the raw query text regardless of input type.
            if isinstance(input_message, BaseMessage):
                user_query = input_message.content
            else:
                user_query = input_message
            retrieved_content = []
            # NOTE(review): retries is 1, so the backoff branch below is
            # currently dead code — confirm whether more attempts were
            # intended.
            retries = 1
            for attempt in range(retries):
                try:
                    raw_rag_content = self.vector_retriever.query(
                        query=user_query,
                        top_k=self.top_k or 5,
                        similarity_threshold=self.similarity or 0.6,
                    )
                    # Remove duplicates and retrieve the whole file
                    # (multiple chunks of the same file collapse to one
                    # full-file entry, keyed by its first occurrence).
                    paths = []
                    for record in raw_rag_content:
                        file_path = record["extra_info"]["file_path"]
                        if file_path not in paths:
                            retrieved_content.append(
                                {
                                    "content": self.search_by_file_path(
                                        file_path
                                    ),
                                    "similarity": record["similarity score"],
                                }
                            )
                            paths.append(file_path)
                    # Most similar files first in the final prompt.
                    retrieved_content = sorted(
                        retrieved_content,
                        key=lambda x: x["similarity"],
                        reverse=True,
                    )
                    full_prompt = self.prompt_template.safe_substitute(
                        type="Retrieved code",
                        repo="\n".join(
                            [record["content"] for record in retrieved_content]
                        ),
                    )
                    new_query = user_query + "\n" + full_prompt
                    # Mutates the caller's BaseMessage in place; plain
                    # strings are wrapped in a fresh user message.
                    if isinstance(input_message, BaseMessage):
                        input_message.content = new_query
                    else:
                        input_message = BaseMessage.make_user_message(
                            role_name="User", content=new_query
                        )
                    break
                except Exception:
                    if attempt < retries - 1:
                        # Exponential backoff between retry attempts.
                        sleep_time = 2**attempt
                        logger.info(
                            f"Retrying qdrant query in {sleep_time} seconds..."
                        )
                        time.sleep(sleep_time)
                    else:
                        # All attempts failed: log and fall through so the
                        # un-augmented message is still sent to the model.
                        logger.error(
                            f"Failed to query qdrant record after {retries} "
                            "attempts."
                        )
        return super().step(input_message, *args, **kwargs)
def reset(self):
super().reset()
if self.processing_mode == ProcessingMode.FULL_CONTEXT:
message = BaseMessage.make_user_message(
role_name=RoleType.USER.value,
content=self.full_text,
)
self.update_memory(message, OpenAIBackendRole.SYSTEM)
else:
self.num_tokens = 0
    def search_by_file_path(self, file_path: str) -> str:
        r"""Search for all payloads in the vector database where
        file_path matches the given value (the same file),
        then sort by piece_num and concatenate text fields to return a
        complete result.

        Args:
            file_path (str): The `file_path` value to filter the payloads.

        Returns:
            str: A concatenated string of the `text` fields sorted by
                `piece_num`.
        """
        # Imported lazily so qdrant-client is only needed once RAG mode is
        # actually exercised.
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        try:
            storage_instance = self.vector_retriever.storage
            # Prefer the explicitly configured collection name; otherwise
            # fall back to the one carried by the storage backend.
            collection_name = (
                self.collection_name or storage_instance.collection_name  # type: ignore[attr-defined]
            )
            # NOTE(review): scroll() returns at most `limit` points, so a
            # file chunked into more than 1000 pieces would be truncated —
            # confirm the chunker never produces that many pieces per file.
            source_data, _ = storage_instance.client.scroll(
                collection_name=collection_name,
                limit=1000,
                scroll_filter=Filter(
                    must=[
                        FieldCondition(
                            key="extra_info.file_path",
                            match=MatchValue(value=file_path),
                        )
                    ]
                ),
                with_payload=True,
                with_vectors=False,
            )
        except Exception as e:
            logger.error(
                f"Error during database initialization or scroll: {e}"
            )
            raise Exception(e)
        # Reassemble the file from its chunks, ordered by piece number.
        results = []
        for point in source_data:
            payload = point.payload
            # assumes every payload carries metadata.piece_num and text —
            # TODO confirm against the chunker's payload schema (a missing
            # key would raise KeyError here).
            piece_num = payload["metadata"]["piece_num"]
            text = payload["text"]
            if piece_num is not None and text:
                results.append({"piece_num": piece_num, "text": text})
        sorted_results = sorted(results, key=lambda x: x["piece_num"])
        full_doc = "\n".join([item["text"] for item in sorted_results])
        return full_doc

View File

@@ -0,0 +1,141 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
    r"""An agent that generates role names based on the task prompt.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)

    Attributes:
        role_assignment_prompt (TextPrompt): A prompt for the agent to generate
            role names.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        num_roles: int = 2,
    ) -> Dict[str, str]:
        r"""Generate role names based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt
                for the task based on which the roles are to be generated.
            num_roles (int, optional): The number of roles to generate.
                (default: :obj:`2`)

        Returns:
            Dict[str, str]: A dictionary mapping role names to their
                descriptions.

        Raises:
            RuntimeError: If the model output does not contain exactly
                ``num_roles`` names and descriptions, or if the conversation
                terminated abnormally.
        """
        self.reset()

        # Build one "<BLANK>. End." slot per requested expert; the model is
        # instructed to fill only the blanks.
        expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
            f"Domain expert {i + 1}: <BLANK>\n"
            f"Associated competencies, characteristics, duties "
            f"and workflows: <BLANK>. End."
            for i in range(num_roles or 0)
        )
        role_assignment_generation_prompt = TextPrompt(
            "You are a role assignment agent, and you're in charge of "
            + "recruiting {num_roles} experts for the following task."
            + "\n==== TASK =====\n {task}\n\n"
            + "Identify the domain experts you'd recruit and detail their "
            + "associated competencies, characteristics, duties and workflows "
            + "to complete the task.\n "
            + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
            + "ONLY answer the BLANKs.\n"
            + expert_prompt
        )
        role_assignment_generation = role_assignment_generation_prompt.format(
            num_roles=num_roles, task=task_prompt
        )

        role_assignment_generation_msg = BaseMessage.make_user_message(
            role_name="Role Assigner", content=role_assignment_generation
        )

        response = self.step(input_message=role_assignment_generation_msg)

        msg = response.msg  # type: BaseMessage
        terminated = response.terminated

        # Distribute the output completions into role names and descriptions.
        # Bug fix: use `\d+` rather than `\d` so experts numbered 10 and
        # above (when num_roles >= 10) still match the template emitted by
        # expert_prompt ("Domain expert {i + 1}: ...").
        role_names = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Domain expert \d+: (.+?)\nAssociated competencies,",
                msg.content,
                re.DOTALL,
            )
        ]
        role_descriptions = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Associated competencies, characteristics, "
                r"duties and workflows: (.+?) End.",
                msg.content,
                re.DOTALL,
            )
        ]

        if len(role_names) != num_roles or len(role_descriptions) != num_roles:
            raise RuntimeError(
                "Got None or insufficient information of roles."
            )
        if terminated:
            raise RuntimeError("Role assignment failed.")

        role_descriptions_dict = {
            role_name: description
            for role_name, description in zip(role_names, role_descriptions)
        }

        return role_descriptions_dict

View File

@@ -0,0 +1,133 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
from camel.utils import create_chunks
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Assistant",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful assistant.",
        )
        super().__init__(system_message, model=model)

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, based on the query.

        Splits the text into chunks, summarizes each chunk with respect to
        the query, then produces one final answer from the per-chunk
        summaries (map-reduce style).

        Args:
            text (str): Text to summarize.
            query (str): What information you want.

        Returns:
            str: Strings with information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that relative to the
            question, but do not directly answer the question.\nquestion:
            {query}\ntext '''
        )
        summary_prompt = summary_prompt.format(query=query)
        # Max length of each chunk
        max_len = 3000
        results = ""
        chunks = create_chunks(text, max_len)
        # Summarize each chunk independently; the chunk index is appended to
        # the prompt so the model can refer to it.
        for i, chunk in enumerate(chunks, start=1):
            prompt = summary_prompt + str(i) + ": " + chunk
            user_msg = BaseMessage.make_user_message(
                role_name="User",
                content=prompt,
            )
            result = self.step(user_msg).msg.content
            results += result + "\n"

        # Final summarization: combine all per-chunk summaries into one
        # answer to the original query.
        final_prompt = TextPrompt(
            '''Here are some summarized texts which split from one text. Using
            the information to answer the question. If can't find the answer,
            you must answer "I can not find the answer to the query" and
            explain why.\n Query:\n{query}.\n\nText:\n'''
        )
        final_prompt = final_prompt.format(query=query)
        prompt = final_prompt + results

        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content
        return response

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue search or not based on the provided answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the user want to continue search, `False`
                otherwise.
        """
        prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        )
        prompt = prompt.format(query=query, answer=answer)

        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content
        # "yes" means the answer already satisfies the query, so searching
        # can stop (False); any other reply keeps the search going (True).
        if "yes" in str(response).lower():
            return False
        return True

410
camel/agents/task_agent.py Normal file
View File

@@ -0,0 +1,410 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union
from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
from camel.utils import get_task_list
# AgentOps decorator setting
try:
import os
if os.getenv("AGENTOPS_API_KEY") is not None:
from agentops import track_agent
else:
raise ImportError
except (ImportError, AttributeError):
from camel.utils import track_agent
@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that specifies a given task prompt by prompting the user to
    provide more details.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is not None:
            # A caller-supplied prompt is used verbatim.
            self.task_specify_prompt = TextPrompt(task_specify_prompt)
        else:
            # Fall back to the built-in template for this task type.
            template = PromptTemplateGenerator().get_task_specify_prompt(
                task_type
            )
            self.task_specify_prompt = template.format(word_limit=word_limit)

        system_message = BaseMessage(
            role_name="Task Specifier",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You can make a task more specific.",
        )
        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.

        Raises:
            RuntimeError: If the conversation terminated abnormally or no
                specification message was produced.
        """
        self.reset()
        specify_prompt = self.task_specify_prompt.format(task=task_prompt)
        if meta_dict is not None:
            specify_prompt = specify_prompt.format(**meta_dict)

        response = self.step(
            BaseMessage.make_user_message(
                role_name="Task Specifier", content=specify_prompt
            )
        )
        if response.terminated:
            raise RuntimeError("Task specification failed.")
        if len(response.msgs) == 0:
            raise RuntimeError("Got no specification message.")
        return TextPrompt(response.msgs[0].content)
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        planner_system_message = BaseMessage(
            role_name="Task Planner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task planner.",
        )
        super().__init__(
            planner_system_message,
            model=model,
            output_language=output_language,
        )
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.

        Raises:
            RuntimeError: If the conversation terminated abnormally or no
                planning message was produced.
        """
        # TODO: Maybe include roles information.
        self.reset()
        planner_request = BaseMessage.make_user_message(
            role_name="Task Planner",
            content=self.task_planner_prompt.format(task=task_prompt),
        )
        planner_response = self.step(planner_request)
        if planner_response.terminated:
            raise RuntimeError("Task planning failed.")
        if len(planner_response.msgs) == 0:
            raise RuntimeError("Got no task planning message.")
        return TextPrompt(planner_response.msgs[0].content)
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:3)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        task_creation_prompt = TextPrompt(
            """Create new a task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

    #. First Task
    #. Second Task
    #. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        # Pre-fill everything except {task_list}, which changes per call to
        # `run` and is substituted there.
        self.task_creation_prompt = task_creation_prompt.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )
        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """
        # The list is interpolated via str(task_list); an empty list is
        # replaced by an empty string so the prompt reads cleanly.
        if len(task_list) > 0:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=task_list
            )
        else:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=""
            )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task creation message.")

        sub_tasks_msg = task_response.msgs[0]
        # Parse the numbered-list response into a plain list of task strings.
        return get_task_list(sub_tasks_msg.content)
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        task_prioritization_prompt = TextPrompt(
            """Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

    #. First task
    #. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        # Pre-fill the objective; {task_list} is substituted per call to `run`.
        self.task_prioritization_prompt = task_prioritization_prompt.format(
            objective=objective
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Prioritizer",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task prioritizer.",
        )
        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        task_prioritization_prompt = self.task_prioritization_prompt.format(
            task_list=task_list
        )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer", content=task_prioritization_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task prioritization message.")

        sub_tasks_msg = task_response.msgs[0]
        # Parse the numbered-list response into a plain list of task strings.
        return get_task_list(sub_tasks_msg.content)

View File

@@ -0,0 +1,20 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseToolAgent
from .hugging_face_tool_agent import HuggingFaceToolAgent
# Public API of the tool-agents subpackage.
__all__ = [
    'BaseToolAgent',
    'HuggingFaceToolAgent',
]

View File

@@ -0,0 +1,39 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import BaseAgent
class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""
        pass

    def step(self) -> None:
        r"""Performs a single step of the agent."""
        pass

    def __repr__(self) -> str:
        # Unambiguous developer-facing form for debugging/logging; added so
        # instances don't fall back to the default `<object at 0x...>` repr.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"description={self.description!r})"
        )

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"

View File

@@ -0,0 +1,206 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Optional
from camel.agents.tool_agents.base import BaseToolAgent
# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.

    Raises:
        ValueError: If the required ``transformers`` toolchain is missing or
            its version is older than 4.31.0.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version

            # The transformers tools API used below only exists in >= 4.31.0.
            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should >= 4.31.0"
                )
            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError):
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please setup the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            )
        # Kept so step()/chat() can unwrap image outputs back to raw objects.
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt
Here are some python code examples of what you can do with this agent:
Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")
# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)
# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")
# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)
# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)
# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)
# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()
# Text to image
capybara_image = {name}.chat("Show me an an image of a capybara")
capybara_image.save("./capybara_image.png")
# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")
# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        super(HuggingFaceToolAgent, self).__init__(name, description)
        # Default execution mode; may be overridden per call in step()/chat().
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        # Unwrap AgentImage results into the underlying raw image object.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        # Unwrap AgentImage results into the underlying raw image object.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

View File

@@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .apibank import APIBankBenchmark
from .apibench import APIBenchBenchmark
from .base import BaseBenchmark
from .gaia import DefaultGAIARetriever, GAIABenchmark
from .nexus import NexusBenchmark
from .ragbench import RAGBenchBenchmark
# Public benchmark API re-exported by this package.
__all__ = [
    "BaseBenchmark",
    "GAIABenchmark",
    "DefaultGAIARetriever",
    "NexusBenchmark",
    "APIBenchBenchmark",
    "APIBankBenchmark",
    "RAGBenchBenchmark",
]

571
camel/benchmarks/apibank.py Normal file
View File

@@ -0,0 +1,571 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
import numpy as np
from rouge import Rouge
from tqdm import tqdm
from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage
from camel.utils import download_github_subdirectory
logger = logging.getLogger(__name__)

# Add current folder to sys.path to enable relative import of the
# downloaded ``api_bank`` package at evaluation time.
current_folder = os.getcwd()
if current_folder not in sys.path:
    sys.path.append(current_folder)
def process_messages(
    chat_history: List[Dict[str, Any]],
    prompt: str,
) -> List[Dict[str, str]]:
    """Convert an API-Bank chat history into OpenAI-style messages.

    Args:
        chat_history (List[Dict[str, Any]]): Dialog turns, each with a
            'role' of 'User', 'AI' or 'API'.
        prompt (str): Text placed as the leading system message.

    Returns:
        List[Dict[str, str]]: The system prompt followed by one message per
        turn. 'API' turns are rendered as
        ``[ApiName(k='v', ...)] Response: <output>`` with role 'system';
        unknown roles are mapped to 'unknown'.
    """
    role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
    processed = [{'role': 'system', 'content': prompt}]
    for turn in chat_history:
        # Unknown roles deliberately fall through as 'unknown'.
        mapped_role = role_map.get(turn['role'], 'unknown')
        if turn['role'] == 'API':
            rendered_params = ', '.join(
                f"{key}='{value}'" for key, value in turn['param_dict'].items()
            )
            content = (
                f"[{turn['api_name']}({rendered_params})] "
                f"Response: {turn['result']['output']}"
            )
        else:
            content = turn['text']
        processed.append({'role': mapped_role, 'content': content})
    return processed
class APIBankBenchmark(BaseBenchmark):
    r"""API-Bank Benchmark adapted from `API-Bank:
    A Comprehensive Benchmark for Tool-Augmented LLMs`
    <https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.

    Args:
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBank benchmark.

        Args:
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        # Predefine data_dir for better import management
        super().__init__("apibank", "api_bank", save_to, processes)
        # Maps test-file name (without extension) to its list of samples.
        self._data: Dict[str, List[APIBankSample]] = dict()  # type: ignore[assignment]

    def download(self):
        r"""Download APIBank dataset and code from Github."""
        repo = "AlibabaResearch/DAMO-ConvAI"
        subdir = "api-bank"
        data_dir = self.data_dir
        download_github_subdirectory(repo, subdir, data_dir)
        # The downloaded repo is imported as a package during evaluation.
        sys.path.insert(0, self.data_dir)
        logger.info("Download completed.")

    def load(self, level: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBank Benchmark dataset.

        Args:
            level (str): Level to run benchmark on.
            force_download (bool, optional): Whether to
                force download the data.
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()
        # NOTE(review): `file_path` is only assigned for "level-1" and
        # "level-2"; any other value raises NameError below — confirm
        # callers always pass a valid level.
        if level == "level-1":
            file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
        elif level == 'level-2':
            file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
        jsonl_files = [
            f for f in os.listdir(file_path) if f.endswith('.jsonl')
        ]
        for file in tqdm(jsonl_files, desc="Processing files"):
            history = []
            with open(file_path / file, 'r') as f:
                for line in f:
                    history.append(json.loads(line))
            samples = APIBankSample.from_chat_history(history)
            self._data[file.rsplit('.', 1)[0]] = samples

        # Change import to relative import in the downloaded python files
        def process_files(folder_path, replacements):
            r"""Replace absolute imports in downloaded files with
            relative import."""
            for file in os.listdir(folder_path):
                if file.endswith(".py"):
                    file_path = os.path.join(folder_path, file)
                    try:
                        with open(file_path, "r", encoding="utf-8") as file:
                            content = file.read()
                        original_content = content
                        for pattern, replacement in replacements:
                            content = re.sub(pattern, replacement, content)
                        # Only rewrite files whose content actually changed.
                        if content != original_content:
                            with open(
                                file_path, "w", encoding="utf-8"
                            ) as file:
                                file.write(content)
                            logger.info(f"Updated file: {file_path}")
                    except Exception as e:
                        logger.info(f"Error processing file {file_path}: {e}")

        api_bank_folder = "api_bank"
        apis_folder = os.path.join(api_bank_folder, "apis")
        apis_replacements = [
            (r"from apis.api", "from .api"),
            (r"from apis import", "from .api import"),
        ]
        api_bank_replacements = [
            (r"from apis", "from .apis"),
            (r"from api_call_extraction", "from .api_call_extraction"),
            (r"f'{basename}", r"f'api_bank.{basename}"),
        ]
        process_files(apis_folder, apis_replacements)
        process_files(api_bank_folder, api_bank_replacements)

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        level: Literal["level-1", "level-2"],
        api_test_enabled=True,
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            level (Literal['level-1', 'level-2']):
                The level to run the benchmark on.
            randomize (bool, optional): Whether to
                randomize the data.
            api_test_enabled (bool): Whether to test
                API calling (`True`) or response (`False`)
                (default: :obj:`False`)
            subset (Optional[int], optional):
                The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        logger.info(f"Running APIBench benchmark on {level}.")
        self.load(level)
        datas = self._data
        # Shuffle and subset data if necessary
        if randomize:
            randomized_items = list(datas.items())
            random.shuffle(randomized_items)
            datas = dict(randomized_items)
        if subset:
            datas = dict(list(datas.items())[:subset])
        logger.info(f"Number of tasks: {len(datas)}")
        # Initialize results storage
        self._results = []
        # The following code are adapted from the evaluator
        # from the original repo:
        tool_search_enabled = level == "level-2"
        dialog_test_enabled = not api_test_enabled
        total_api_calls, correct_api_calls, rougel_scores = 0, 0, []
        with open(self.save_to, "w") as f:
            for test in tqdm(datas, desc="Running"):
                samples = self._data[test]
                evaluator = Evaluator(samples)  # type: ignore[arg-type]
                for sample_id in evaluator.get_all_sample_ids():
                    # Process sample and generate response
                    sample = evaluator.dataset[sample_id]
                    if (
                        sample.ground_truth['role'] == 'API'
                        and api_test_enabled
                    ):
                        if tool_search_enabled:
                            # Level-2 supplies only the ToolSearcher
                            # description instead of the sample's own APIs.
                            _, chat_history = evaluator.get_model_input(
                                sample_id
                            )
                            api_descriptions = evaluator.get_api_description(
                                'ToolSearcher'
                            )
                        else:
                            api_descriptions, chat_history = (
                                evaluator.get_model_input(sample_id)
                            )
                        messages = process_messages(
                            chat_history, API_CALL_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)
                        api_call = get_api_call(model_output)
                        # Evaluate API call
                        if api_call:
                            try:
                                correct, model_output_result = (
                                    evaluator.evaluate(sample_id, api_call)
                                )
                            except AssertionError as e:
                                if 'The API name is not correct.' not in str(
                                    e
                                ):
                                    raise e
                                logging.info('AssertionError: {}'.format(e))
                                correct = False
                        else:
                            model_output_result = 'No API call found'
                            correct = False
                        if correct:
                            correct_api_calls += 1
                            logging.info(
                                'Correct API call: {} Ground truth: {}'.format(
                                    api_call, sample.ground_truth
                                )
                            )
                        else:
                            logging.info(
                                'Incorrect model output: {} Result: {} \
                                Ground truth: {} File: {} Sample ID: {} \
                                Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    model_output_result,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )
                        total_api_calls += 1
                        self._results.append(
                            {
                                'Role': 'API',
                                'Model_output': model_output,
                                'Model_output_result': model_output_result,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                                'Correct': correct,
                            }
                        )
                        json_str = json.dumps(
                            self._results[-1], indent=2, ensure_ascii=False
                        )
                        f.write(json_str + "\n")
                    elif (
                        sample.ground_truth['role'] == 'AI'
                        and dialog_test_enabled
                    ):
                        # Process sample and generate response
                        api_descriptions, chat_history = (
                            evaluator.get_model_input(sample_id)
                        )
                        messages = process_messages(
                            chat_history, RESPONSE_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)
                        # Evaluate model response
                        if model_output:
                            score = evaluator.evaluate(sample_id, model_output)
                        else:
                            score = 0
                        rougel_scores.append(score)
                        if score < 0.2:
                            logging.info(
                                'Low score: {} Score: {} Ground truth: {} \
                                Test: {} Sample ID: {} \
                                Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    score,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )
                        self._results.append(
                            {
                                'Role': 'AI',
                                'Model_output': model_output,
                                'Score': score,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                            }
                        )
                        json_str = json.dumps(
                            self._results[-1], indent=2, ensure_ascii=False
                        )
                        f.write(json_str + "\n")
                        f.flush()
        if api_test_enabled:
            return {
                'total': total_api_calls,
                'correct': correct_api_calls,
                "accuracy": correct_api_calls / total_api_calls
                if total_api_calls
                else 0,
            }
        elif dialog_test_enabled:
            return {'Dialog_score': np.mean(rougel_scores)}
# The following code are migrated from the original repo:
# https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
def agent_call(messages: List[Dict], agent: ChatAgent):
    r"""Replay ``messages`` into the agent's memory and return its reply.

    All messages except the last are recorded into the agent's memory;
    the final one is sent via ``agent.step``. The agent is reset afterwards
    so successive calls do not share state.

    Args:
        messages (List[Dict]): OpenAI-style messages with 'role'/'content'.
        agent (ChatAgent): The agent used to produce the response.

    Returns:
        The text content of the agent's first response message.
    """
    last_index = len(messages) - 1
    for index, entry in enumerate(messages):
        role = entry['role']
        if role == 'user':
            message = BaseMessage.make_user_message(
                role_name="CAMEL User", content=entry['content']
            )
        elif role == 'assistant':
            message = BaseMessage.make_assistant_message(
                role_name="CAMEL Assistant", content=entry['content']
            )
        elif role == 'system':
            # System turns are replayed as assistant-side memory entries.
            message = BaseMessage.make_assistant_message(
                role_name="System", content=entry['content']
            )
        else:
            raise ValueError(f"Unrecognized role: {role}")
        # Keep the final message out of memory; it is sent to step() below.
        if index == last_index:
            break
        agent.record_message(message)
    response = agent.step(message)
    model_output = response.msgs[0].content
    agent.reset()
    return model_output
def calculate_rouge_l_score(reference, hypothesis):
    r"""Compute the ROUGE-L F1 score of ``hypothesis`` against ``reference``."""
    scorer = Rouge()
    all_scores = scorer.get_scores(hypothesis, reference)
    # get_scores returns one entry per pair; take the single pair's F1.
    return all_scores[0]['rouge-l']['f']
def get_api_call(model_output):
    r"""Extract the first ``[ApiName(...)]`` call from ``model_output``.

    Returns:
        The full bracketed call string, or ``None`` if no call is present.
    """
    match = re.search(r"\[(\w+)\((.*)\)\]", model_output)
    return match.group(0) if match else None
class APIBankSample:
    r"""APIBank sample used to load the datasets.

    A sample pairs the dialog prefix (``chat_history``) with the turn the
    model must reproduce (``ground_truth``), plus the set of API names
    mentioned anywhere in the dialog (``apis``).
    """

    def __init__(self, chat_history, apis, ground_truth):
        self.chat_history = chat_history
        self.apis = apis
        self.ground_truth = ground_truth

    def __repr__(self):
        return (
            f'Sample(chat_history={self.chat_history}, '
            f'apis={self.apis}, ground_truth={self.ground_truth})'
        )

    @classmethod
    def from_chat_history(cls, chat_history):
        r"""Build two samples per API turn: one whose target is the API call
        itself, and one whose target is the turn following it."""
        api_names = set()
        api_positions = []
        for position, turn in enumerate(chat_history):
            if turn['role'] == 'API':
                api_names.add(turn['api_name'])
                api_positions.append(position)
        samples = []
        for position in api_positions:
            samples.append(
                cls(chat_history[:position], api_names, chat_history[position])
            )
            samples.append(
                cls(
                    chat_history[: position + 1],
                    api_names,
                    chat_history[position + 1],
                )
            )
        return samples
class Evaluator:
    r"""Evaluator for APIBank benchmark.

    Wraps a list of :class:`APIBankSample` and scores model outputs either
    as API calls ('API' ground truth) or as dialog turns ('AI' ground
    truth, ROUGE-L).
    """

    def __init__(self, samples: List[APIBankSample]):
        # Place holder for import as the import
        # only works after the files have been downloaded
        try:
            from api_bank.tool_manager import (  # type: ignore[import-not-found]
                ToolManager,
            )
        except Exception as e:
            logger.info(f"{e}, Module will be imported after download.")
        self.dataset = samples
        self.sample_ids = list(range(len(self.dataset)))
        # ToolManager resolves its API implementations relative to the
        # current working directory, so temporarily enter the downloaded
        # ``api_bank`` folder. NOTE(review): not exception-safe — confirm
        # callers tolerate a changed cwd if ToolManager raises.
        os.chdir("api_bank")
        self.tool_manager = ToolManager("apis")
        os.chdir("..")

    def get_all_sample_ids(self):
        r"""Returns the ids of all loaded samples."""
        return self.sample_ids

    def get_api_description(self, api_name):
        r"""Returns the textual description of ``api_name``."""
        return self.tool_manager.get_api_description(api_name)

    def get_model_input(self, sample_id: int):
        r"""Returns ``(api_descriptions, chat_history)`` for a sample."""
        sample = self.dataset[sample_id]
        apis = sample.apis
        chat_history = sample.chat_history
        api_descriptions = []
        for api_name in apis:
            api_descriptions.append(
                self.tool_manager.get_api_description(api_name)
            )
        api_description = '\n'.join(api_descriptions)
        return api_description, chat_history

    def evaluate(self, sample_id, model_output):
        r"""Scores ``model_output`` against the sample's ground truth.

        Returns:
            For 'API' ground truth, a ``(correct, result)`` tuple; for
            'AI' ground truth, the ROUGE-L score rounded to 4 decimals.
        """
        # Place holder import; only resolvable after the dataset download.
        try:
            from api_bank.api_call_extraction import (  # type: ignore[import-not-found]
                parse_api_call,
            )
        except Exception as e:
            logger.info(f"{e}, Module will be imported after download.")
        sample = self.dataset[sample_id]
        ground_truth = sample.ground_truth
        if ground_truth['role'] == 'API':
            api_name, param_dict = parse_api_call(model_output)
            if api_name != ground_truth['api_name']:
                return False, 'API Name Mismatch: {} vs {}'.format(
                    api_name, ground_truth['api_name']
                )
            try:
                # Execute the predicted call against the real tool.
                result = self.tool_manager.api_call(api_name, **param_dict)
            except Exception as e:
                return False, str(e)
            api = self.tool_manager.init_tool(api_name)
            try:
                correct = api.check_api_call_correctness(
                    result, ground_truth['result']
                )
            except KeyError:
                # Result is missing an expected key; count it as incorrect.
                correct = False
                result = 'KeyError' + str(result)
            return correct, result
        elif ground_truth['role'] == 'AI':
            score = calculate_rouge_l_score(ground_truth['text'], model_output)
            return round(score, 4)
# Prompt used when the model must emit exactly one bracketed API call.
API_CALL_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the API request \
that the AI should call in step t+1 and output it in the \
format of [ApiName(key1='value1', key2='value2', ...)], \
replace the ApiName with the actual API name, and \
replace the key and value with the actual parameters. \
Your output should start with a square bracket "[" \
and end with a square bracket "]". Do not output any \
other explanation or prompt or the result of the API call in your output.
This year is 2023.
Input:
User: [User's utterence]
AI: [AI's utterence]
Expected output:
[ApiName(key1='value1', key2='value2', ...)]
API descriptions:
'''

# Prompt used when the model must produce the next AI dialog turn.
RESPONSE_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the next \
dialog that the AI should response after the API call t.
This year is 2023.
Input:
User: [User's utterence]
AI: [AI's utterence]
[ApiName(key1='value1', key2='value2', …)]
Expected output:
AI: [AI's utterence]
API descriptions:
'''

View File

@@ -0,0 +1,499 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, Literal, Optional
import tree_sitter_python as tspython
from tqdm import tqdm
from tree_sitter import Language, Parser
from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.utils import download_github_subdirectory
logger = logging.getLogger(__name__)
# Mapping of dataset names to file names
# 'Oracle' retriever used here which means all the full
# API documentation will be included in the prompt
# NOTE: the 'train' files are listed for completeness; load() only reads
# the 'api', 'eval' and 'questions' entries.
dataset_mapping = {
    "huggingface": {
        "api": "huggingface_api.jsonl",
        "eval": "huggingface_eval.json",
        "train": "huggingface_train.json",
        "questions": "questions_huggingface_oracle.jsonl",
    },
    "tensorflowhub": {
        "api": "tensorflowhub_api.jsonl",
        "eval": "tensorflow_eval.json",
        "train": "tensorflow_train.json",
        "questions": "questions_tensorflowhub_oracle.jsonl",
    },
    "torchhub": {
        "api": "torchhub_api.jsonl",
        "eval": "torchhub_eval.json",
        "train": "torchhub_train.json",
        "questions": "questions_torchhub_oracle.jsonl",
    },
}
# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string.

    Args:
        question (str): The user question to embed in the prompt.
        dataset_name (str): One of ``"torchhub"``, ``"huggingface"`` or
            ``"tensorflowhub"``; selects the domain list for the prompt.

    Returns:
        str: The full prompt to send to the agent.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
        should include one of {Classification, Semantic Segmentation, \
        Object Detection, Audio Separation, Video Classification, \
        Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
        Extraction, Multimodal Text-to-Image, Multimodal \
        Image-to-Text, Multimodal Text-to-Video, \
        Multimodal Visual Question Answering, Multimodal Document \
        Question Answer, Multimodal Graph Machine Learning, \
        Computer Vision Depth Estimation, Computer Vision Image \
        Classification, Computer Vision Object Detection, \
        Computer Vision Image Segmentation, Computer Vision \
        Image-to-Image, Computer Vision Unconditional \
        Image Generation, Computer Vision Video Classification, \
        Computer Vision Zero-Shor Image Classification, \
        Natural Language Processing Text Classification, \
        Natural Language Processing Token Classification, \
        Natural Language Processing Table Question Answering, \
        Natural Language Processing Question Answering, \
        Natural Language Processing, Zero-Shot Classification \
        Natural Language Processing Translation, Natural Language \
        Processing Summarization, Natural Language Processing \
        Conversational, Natural Language Processing Text \
        Generation, Natural Language Processing Fill-Mask, \
        Natural Language Processing Text2Text Generation, \
        Natural Language Processing Sentence Similarity, \
        Audio Text-to-Speech, Audio Automatic Speech Recognition, \
        Audio Audio-to-Audio, Audio Audio Classification, \
        Audio Voice Activity Detection, Tabular Tabular \
        Classification, Tabular Tabular Regression, \
        Reinforcement Learning Reinforcement Learning, \
        Reinforcement Learning Robotics }"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
        and should include one of {text-sequence-alignment, \
        text-embedding, text-language-model, text-preprocessing, \
        text-classification, text-generation, text-question-answering, \
        text-retrieval-question-answering, text-segmentation, \
        text-to-mel, image-classification, image-feature-vector, \
        image-object-detection, image-segmentation, \
        image-generator, image-pose-detection, image-rnn-agent, \
        image-augmentation, image-classifier, image-style-transfer, \
        image-aesthetic-quality, image-depth-estimation, \
        image-super-resolution, image-deblurring, image-extrapolation, \
        image-text-recognition, image-dehazing, image-deraining, \
        image-enhancemenmt, image-classification-logits, \
        image-frame-interpolation, image-text-detection, image-denoising, \
        image-others, video-classification, video-feature-extraction, \
        video-generation, video-audio-text, video-text, \
        audio-embedding, audio-event-classification, audio-command-detection, \
        audio-paralinguists-classification, audio-speech-to-text, \
        audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # Previously this branch only logged and then crashed with a
        # NameError on the undefined `domains`; fail fast instead.
        logger.info("Error: API name is not supported.")
        raise ValueError(f"Unsupported dataset name: {dataset_name}")
    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
        <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
        <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
        Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
        that calls api.\n 3. The $API_PROVIDER should be the \
        programming framework used.\n4. $EXPLANATION should be \
        a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
        Do not repeat the format in your answer."
    )
    return prompt
class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    # TODO: Integrate retriever (pending)

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self):
        r"""Download the APIBench dataset."""
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )
        # Question files ship with the gorilla repo, not the HF dataset.
        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        data_dir = self.data_dir
        download_github_subdirectory(repo, subdir, data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force
                download the data. (default: :obj:`False`)
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Helper function to load JSON lines from a file."""
            try:
                with open(file_path, "r") as f:
                    return [json.loads(line) for line in f]
            except FileNotFoundError:
                raise FileNotFoundError(f"File not found: {file_path}")
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                )

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )

        for label in ['api', 'eval', 'questions']:
            file_name = dataset_mapping[dataset_name][label]
            # Question files live in the per-dataset subdirectory; api and
            # eval files sit directly under data_dir.
            file_path = (
                dataset_path / file_name
                if label == 'questions'
                else self.data_dir / file_name
            )
            # Load data based on label type
            if label in ['api', 'questions', 'eval']:
                data = load_json_lines(file_path)
                if label == 'eval':
                    # Extract 'api_data' specifically for eval label
                    data = [item['api_data'] for item in data]
                self._data[label] = data
            else:
                raise ValueError(f"Unknown label: {label}")

        # Pre-parse every reference API call into a tree-sitter AST so
        # responses can be matched structurally during evaluation.
        ast_database = []
        for data in self._data['api']:
            ast_tree = ast_parse(data['api_call'])
            ast_database.append(ast_tree)
        self._data['ast'] = ast_database

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            dataset_name (Literal["huggingface",
                "tensorflowhub", "torchhub"]):
                The dataset to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: Totals, accuracy and hallucination rate.
        """
        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")
        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        datas = self._data['questions']
        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]
        logger.info(f"Number of tasks: {len(datas)}")
        # Initialize results storage
        self._results = []
        with open(self.save_to, "w") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                try:
                    # Generate response
                    responses = agent.step(prompt)
                    response = responses.msgs[0].content
                    api_database = self._data['api']
                    qa_pairs = self._data['eval']
                    ast_database = self._data['ast']
                    question_id = question['question_id']
                    # Evaluate response
                    error, correct, hallucination = evaluate_response(
                        response,
                        question_id,
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": response,
                            "correct": correct,
                            "hallucination": hallucination,
                            "error": str(error) if error else None,
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": None,
                            "correct": False,
                            "hallucination": False,
                            "error": str(e),
                        }
                    )
                agent.reset()
                json_str = json.dumps(
                    self._results[-1], indent=2, ensure_ascii=False
                )
                f.write(json_str + "\n")
                f.flush()

        total = len(self._results)
        # NOTE(review): reads `self.results` while writes above target
        # `self._results` — presumably BaseBenchmark exposes a `results`
        # accessor over `_results`; confirm.
        correct = sum(r["correct"] for r in self.results)
        hallucination = sum(r["hallucination"] for r in self.results)
        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }
# This code is modified from the
# evaluators in the original repo
# https://github.com/ShishirPatil/gorilla
def get_all_sub_trees(root_node):
    r"""Collect every non-trivial subtree reachable from ``root_node``.

    Performs a depth-first walk (LIFO stack). Each collected entry is
    ``[str(node), depth, node, first_child_text_or_None]``; only children
    that themselves have children are pushed for further expansion.
    """
    collected = []
    stack = [[root_node, 1]]
    while stack:
        node, depth = stack.pop()
        if node.child_count > 0:
            collected.append(
                [str(node), depth, node, node.children[0].text]
            )
        else:
            collected.append([str(node), depth, node, None])
        for child in node.children:
            if len(child.children) != 0:
                stack.append([child, depth + 1])
    return collected
# Parse the program into AST trees
def ast_parse(candidate):
    r"""Parse ``candidate`` Python source text into a tree-sitter AST.

    Args:
        candidate (str): Python source code to parse.

    Returns:
        The root node of the parsed tree-sitter syntax tree.
    """
    py_language = Language(tspython.language())
    tree = Parser(py_language).parse(bytes(candidate, "utf8"))
    return tree.root_node
# Get all the arguments in the ast tree
def get_args(node, dataset_name):
    r"""Extract argument value nodes from the call expression under ``node``.

    The argument list is located at
    ``node.children[0].children[0].children[1].children`` (the argument list
    of the first call expression). Which arguments are kept depends on the
    dataset's call conventions.

    Args:
        node: A tree-sitter node containing a call expression.
        dataset_name (str): One of ``"huggingface"``, ``"tensorflowhub"``
            or ``"torchhub"``; any other value yields an empty list.

    Returns:
        list: Raw ``bytes`` of the selected argument values.
    """
    if node.child_count == 0:
        return []
    if dataset_name not in ("huggingface", "tensorflowhub", "torchhub"):
        return []
    collected = []
    punctuation = ("(", ")", ",")
    for arg_node in node.children[0].children[0].children[1].children:
        text = arg_node.text.decode()
        if dataset_name == "huggingface":
            if "=" in text:
                # Keyword argument: keep only the value (child index 2).
                collected.append(arg_node.children[2].text)
            elif text not in punctuation:
                collected.append(arg_node.text)
        elif dataset_name == "tensorflowhub":
            if 'model=' in text or 'model =' in text:
                collected.append(arg_node.children[2].text)
            elif text not in punctuation:
                collected.append(arg_node.text)
        elif dataset_name == "torchhub":
            # torchhub only tracks the repo/model keyword values.
            if "repo_or_dir" in text or "model" in text:
                collected.append(arg_node.children[2].text)
    return collected
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    r"""Find the first reference AST that the candidate call matches.

    Args:
        candidate_subtree_list: Subtree entries produced by
            ``get_all_sub_trees`` (``[sexp, depth, node, first_child_text]``).
        base_tree_list: Reference ASTs of the ground-truth API calls.
        dataset_name: Dataset key controlling how arguments are extracted.

    Returns:
        int: Index into ``base_tree_list`` of the first reference call whose
            API name and all argument values occur in the candidate, or
            ``-1`` if none matches.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip reference entries that contain no parsable call expression.
        if base_tree.children[0].children[0].child_count == 0:
            continue
        api_name = base_tree.children[0].children[0].children[0].text
        # Select the candidate subtree whose first child matches the API
        # name. NOTE(review): if nothing matches, the for-loop variable
        # leaks and `candidate_tree` is the *last* entry (behavior kept
        # from the upstream gorilla evaluator); an empty
        # `candidate_subtree_list` would raise NameError here.
        for candidate_tree in candidate_subtree_list:
            if candidate_tree[3] == api_name:
                break
        # Now we have a sub-tree
        candidate_tree = candidate_tree[2]
        args_list = get_args(base_tree, dataset_name)
        if len(args_list) == 0:
            continue
        ast_match = True
        # Every reference argument value (quotes stripped) must appear
        # somewhere in the candidate subtree's source text.
        for arg in args_list:
            if (
                arg.decode().lstrip("'").rstrip("'")
                not in candidate_tree.text.decode()
            ):
                ast_match = False
                break
        if ast_match:
            return idx
    return -1
def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    r"""Evaluate an agent response against the APIBench reference data.

    Args:
        response (str): Raw model output, expected to contain an
            ``api_call`` section.
        question_id (int): 1-based id of the question being evaluated.
        dataset_name (str): One of ``"huggingface"``, ``"tensorflowhub"``
            or ``"torchhub"``.
        api_database (list): Reference API entries (dicts with a
            ``domain`` key).
        qa_pairs (list): Ground-truth QA pairs, indexed by
            ``question_id - 1``.
        ast_database (list): Parsed ASTs of the reference API calls.

    Returns:
        tuple: ``(error, correct, hallucination)`` where ``error`` is an
            exception instance or ``None``.
    """
    try:
        # Index the "api_call" domain
        output = response.split("api_call")
        if len(output) == 1:
            api_call = output[0]
        else:
            # Parse the output: strip anything after "api_provider" and
            # slice out the call between ':' and the last ')'.
            output = output[1].split("api_provider")[0]
            if ":" not in output:
                start = 0
            else:
                start = output.index(":")
            if ")" not in output:
                end = -2
            else:
                end = output.rindex(")")
            api_call = output[start + 2 : end + 1]
        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            print(f"Error parsing api_call: {api_call}, error: {parse_error}")
            return parse_error, False, False
        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        if database_index == -1:
            # The call does not exist in the reference database: it is a
            # hallucination and therefore cannot be correct. (Previously
            # the code fell through here and indexed api_database[-1],
            # silently comparing against the *last* reference entry and
            # potentially overwriting the hallucination flag.)
            return None, False, True
        # We index our reference api_call and check for functionality:
        # the matched call's domain must equal the ground-truth domain.
        ref_api_call = api_database[database_index]
        correct = (
            ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']
        )
        return None, correct, False
    except Exception as e:
        print(f'Error parsing response: {response}, error: {e}')
        return e, False, False

152
camel/benchmarks/base.py Normal file
View File

@@ -0,0 +1,152 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from camel.agents import ChatAgent
logger = logging.getLogger(__name__)
class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """
    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory. Created (including
                parents) if it does not exist yet.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)

        Raises:
            NotADirectoryError: If ``data_dir`` exists but is not a
                directory.
        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        # Loaded splits keyed by split name (e.g. "train"/"valid"/"test").
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        # Per-task result records accumulated by `run`.
        self._results: List[Dict[str, Any]] = []
    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass
    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass
    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data, loading it on first access.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["train"]
    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data, loading it on first access.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["valid"]
    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data, loading it on first access.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["test"]
    @abstractmethod
    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass
    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results

482
camel/benchmarks/gaia.py Normal file
View File

@@ -0,0 +1,482 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import os
import random
import re
import string
import uuid
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from tqdm import tqdm
from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage
from camel.retrievers.auto_retriever import AutoRetriever
logger = logging.getLogger(__name__)
class RetrieverProtocol(Protocol):
    r"""Protocol for the retriever class. Any retriever class implementing
    this protocol can be used in the benchmark class.
    """
    def retrieve(
        self, query: str, contents: List[str], **kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        r"""Retrieve the relevant content for the query.

        Args:
            query (str): The query to retrieve the content for.
            contents (List[str]): The list of contents to search in.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Returns:
            Dict[str, Any]: The relevant content for the query.
        """
        ...
    def reset(self, **kwargs) -> bool:
        r"""Reset the retriever.

        Some benchmarks may require resetting the retriever
        after each query.

        Args:
            **kwargs: Additional keyword arguments.

        Returns:
            bool: True if the reset was successful, False otherwise.
        """
        ...
class DefaultGAIARetriever(AutoRetriever):
    r"""Default retriever for the GAIA benchmark.

    This retriever uses AutoRetriever in camel to retrieve the content based
    on the query.
    """
    def retrieve(
        self, query: str, contents: List[str], **kwargs: Any
    ) -> Dict[str, Any]:
        r"""Retrieve the content based on the query.

        Args:
            query (str): The query to search for.
            contents (List[str]): The list of contents to search from.
            **kwargs (Any): The keyword arguments to pass to the
                retriever.

        Returns:
            Dict[str, Any]: The retrieved content.
        """
        return self.run_vector_retriever(query, contents, **kwargs)  # type: ignore[arg-type]
    def reset(self, **kwargs: Any) -> bool:
        r"""Reset the retriever.

        Points the vector storage at a per-task subdirectory (named after
        ``task_id``, or a random UUID if none is given), creating it when
        needed.

        Args:
            **kwargs (Any): The keyword arguments to pass to the
                retriever.

        Returns:
            bool: Whether the reset was successful.
        """
        base_path = Path(self.vector_storage_local_path or os.getcwd())
        storage_dir = base_path / str(kwargs.get("task_id", uuid.uuid4()))
        if not storage_dir.exists():
            try:
                storage_dir.mkdir(parents=True)
            except Exception as e:
                logger.error(
                    "Error in creating directory: " + f"{storage_dir}: {e!s}"
                )
                return False
        self.vector_storage_local_path = str(storage_dir)
        return True
class GAIABenchmark(BaseBenchmark):
    r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
    Assistants"
    <https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        retriever (Optional[RetrieverProtocol]): The retriever to use.
            (default: :obj:`None`)
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """
    def __init__(
        self,
        data_dir: str,
        save_to: str,
        retriever: Optional[RetrieverProtocol] = None,
        processes: int = 1,
    ):
        r"""Initialize the GAIA benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            retriever (Optional[RetrieverProtocol], optional): The retriever
                to use. (default: :obj:`None`)
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("gaia", data_dir, save_to, processes)
        self.retriever = retriever or DefaultGAIARetriever()
    def download(self):
        r"""Download the GAIA dataset from the Hugging Face Hub."""
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gaia-benchmark/GAIA",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )
    def load(self, force_download=False):
        r"""Load the GAIA dataset, downloading it first if necessary.

        Args:
            force_download (bool, optional): Whether to
                force download the data.
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()
        # Define validation and test directories
        valid_dir = self.data_dir / "2023/validation"
        test_dir = self.data_dir / "2023/test"
        # Check if directories exist; if not, download the data
        if not valid_dir.is_dir() or not test_dir.is_dir():
            logger.info("Data not found. Downloading data.")
            self.download()
        # Load metadata for both validation and test datasets
        for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
            self._data[label] = []
            with open(path / "metadata.jsonl", "r") as f:
                lines = f.readlines()
                for line in lines:
                    data = json.loads(line)
                    # "0-0-0-0-0" is a placeholder task id in the dataset.
                    if data["task_id"] == "0-0-0-0-0":
                        continue
                    if data["file_name"]:
                        # Resolve attachments relative to the split folder.
                        data["file_name"] = path / data["file_name"]
                    self._data[label].append(data)
        return self
    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError("GAIA does not have a training set.")
    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        level: Union[int, List[int], Literal["all"]],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            on (Literal["valid", "test"]): The set to run the benchmark.
            level (Union[int, List[int], Literal["all"]]): The level to run
                the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        # Validate inputs
        if on not in ["valid", "test"]:
            raise ValueError(
                f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
            )
        levels = (
            [1, 2, 3]
            if level == "all"
            else [level]
            if isinstance(level, int)
            else level
        )
        if not all(
            isinstance(level, int) and level in [1, 2, 3] for level in levels
        ):
            raise ValueError(
                f"Invalid value for `level`: {level}, expected 1, 2, 3 "
                "or 'all'."
            )
        logger.info(f"Running benchmark on {on} set at levels {levels}.")
        # Lazily load the data, mirroring the base-class properties.
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        datas = [data for data in self._data[on] if data["Level"] in levels]
        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]
        logger.info(f"Number of tasks: {len(datas)}")
        # Initialize results storage
        self._results = []
        # Process tasks
        with open(self.save_to, "w") as f:
            for task in tqdm(datas, desc="Running"):
                if not self._prepare_task(task):
                    continue
                try:
                    result = agent.step(self._create_user_message(task))
                    self._process_result(agent, task, result, f)
                except Exception as e:
                    self._handle_error(task, e, f)
                finally:
                    agent.reset()
        return self._generate_summary()
    def _prepare_task(self, task: Dict[str, Any]) -> bool:
        r"""Prepare the task by validating and enriching its data.

        For document attachments, retrieved context is appended to the
        question text. Returns False when the task must be skipped.
        """
        if task["file_name"]:
            file_path = Path(task["file_name"])
            if not file_path.exists():
                logger.info(
                    f"Skipping task because file not found: {file_path}"
                )
                return False
            if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
                if not self.retriever.reset(task_id=task["task_id"]):
                    return False
                retrieved_info = self.retriever.retrieve(
                    query=task["Question"], contents=[task["file_name"]]
                )
                retrieved_content = [
                    item["text"]
                    for item in retrieved_info.get("Retrieved Context", [])
                ]
                if retrieved_content:
                    task["Question"] += "\n" + "\n".join(retrieved_content)
            else:
                logger.info(
                    f"Skipping task due to unsupported file "
                    f"format: {file_path.suffix}"
                )
                return False
        return True
    def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
        r"""Create a user message from a task."""
        return BaseMessage.make_user_message(
            role_name="User",
            content=task["Question"],
        )
    def _process_result(
        self,
        agent: ChatAgent,
        task: Dict[str, Any],
        result: Any,
        file_obj: Any,
    ) -> None:
        r"""Process, score and store the result of a task."""
        model_answer = self.get_final_answer(result.msgs[0].content)
        final_answer = task["Final answer"]
        score = self.question_scorer(model_answer, final_answer)
        tool_calls = result.info.get("tool_calls", [])
        result_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": model_answer,
            "ground_truth": final_answer,
            "tool_calls": [tool.model_dump() for tool in tool_calls],
            "error": None,
            "score": int(score),
            "history": agent.memory.get_context(),
        }
        self._results.append(result_data)
        # `ensure_ascii` is a json.dumps option; passing it to
        # TextIO.write (as before) raised TypeError at runtime.
        file_obj.write(
            json.dumps(result_data, indent=2, ensure_ascii=False) + "\n"
        )
        file_obj.flush()
    def _handle_error(
        self, task: Dict[str, Any], error: Exception, file_obj: Any
    ) -> None:
        r"""Handle errors encountered during task processing."""
        logger.warning(f"Error processing task {task['task_id']}: {error}")
        error_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": "ERROR",
            "ground_truth": task["Final answer"],
            "tool_calls": [],
            "error": str(error),
            "score": 0,
        }
        self._results.append(error_data)
        # Same fix as in _process_result: keep ensure_ascii inside dumps.
        file_obj.write(
            json.dumps(error_data, indent=2, ensure_ascii=False) + "\n"
        )
        file_obj.flush()
    def _generate_summary(self) -> Dict[str, Any]:
        r"""Generate and return a summary of the benchmark results."""
        return {
            "total": len(self._results),
            "correct": sum(result["score"] for result in self._results),
            "results": self._results,
        }
    def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
        r"""Scorer for the GAIA benchmark.
        https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
        scorer.py

        Args:
            model_answer (str): The model answer.
            ground_truth (str): The ground truth answer.

        Returns:
            bool: The score of the model
        """
        def is_float(element: Any) -> bool:
            try:
                float(element)
                return True
            except ValueError:
                return False
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = self.normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)
        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(
                f"Evaluating {model_answer} as a comma separated list."
            )
            gt_elems = self.split_string(ground_truth)
            ma_elems = self.split_string(model_answer)
            if len(gt_elems) != len(ma_elems):
                # NOTE: the stray `UserWarning` argument (copied from
                # warnings.warn) made logging's %-formatting fail.
                logger.warning(
                    "Answer lists have different lengths, returning False."
                )
                return False
            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = self.normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = self.normalize_str(ma_elem, remove_punct=False)
                    gt_elem = self.normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = self.normalize_str(model_answer)
            gt_elem = self.normalize_str(ground_truth)
            return ma_elem == gt_elem
    def normalize_number_str(self, number_str: str) -> float:
        r"""Normalize a numeric string by stripping $, % and , characters.

        Returns ``float("inf")`` when the remainder is not a number, so
        the comparison in ``question_scorer`` fails instead of raising.
        """
        for char in ["$", "%", ","]:
            number_str = number_str.replace(char, "")
        try:
            return float(number_str)
        except ValueError:
            logger.error(
                f"String {number_str} cannot be normalized to number str."
            )
            return float("inf")
    def split_string(
        self, s: str, char_list: Optional[List[str]] = None
    ) -> list[str]:
        r"""Split a string based on a list of characters.

        Args:
            s (str): The string to split.
            char_list (Optional[List[str]], optional): The list of
                characters to split on.
                (default: :obj:`None`, meaning ``[",", ";"]``)
        """
        if char_list is None:
            char_list = [",", ";"]
        pattern = f"[{''.join(char_list)}]"
        return re.split(pattern, s)
    def normalize_str(self, input_str, remove_punct=True) -> str:
        r"""Normalize a string.

        Args:
            input_str: The input string to normalize.
            remove_punct: Whether to remove punctuation.

        Returns:
            str: The normalized string.
        """
        no_spaces = re.sub(r"\s", "", input_str)
        if remove_punct:
            translator = str.maketrans("", "", string.punctuation)
            return no_spaces.lower().translate(translator)
        else:
            return no_spaces.lower()
    def get_final_answer(self, content: str) -> str:
        r"""Get the final answer from the content.

        Args:
            content (str): The content to extract the final answer from.

        Returns:
            str: The final answer.
        """
        final_answer_index = content.find("FINAL ANSWER")
        if final_answer_index == -1:
            return "FINAL ANSWER not found"
        # NOTE(review): assumes the marker is exactly "FINAL ANSWER: "
        # (colon + space); other separators would misalign the slice.
        start_index = final_answer_index + len("FINAL ANSWER: ")
        final_answer_content = content[start_index:].strip()
        return final_answer_content

517
camel/benchmarks/nexus.py Normal file
View File

@@ -0,0 +1,517 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import ast
import json
import logging
import os
import random
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
logger = logging.getLogger(__name__)
# Define the data class
@dataclass
class NexusSample:
    r"""Nexus benchmark dataset sample.

    Attributes:
        input: The natural-language user query.
        output: The ground-truth function call(s) for the query.
    """

    input: str
    output: str
@dataclass
class NexusTool:
    r"""Nexus benchmark tool.

    Attributes:
        function_calls: The tool's function signature (rendered after
            ``def`` in the generated prompt).
        descriptions: The tool's description, rendered as the function's
            docstring in the prompt.
    """

    function_calls: str
    descriptions: str
dataset_mapping = {
"NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
"VirusTotal": "Nexusflow/VirusTotalBenchmark",
"PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
"ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
"OTX": "Nexusflow/OTXAPIBenchmark",
"VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
"VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
"NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
}
TOOL_CALLING_PROMPT = """
You are given multiple functions and a user query.
Please proceed with generating a function call for the function \
with the proper arguments that best answers the given prompt.
Respond with nothing but the function call ONLY, such that I can \
directly execute your function call without any post processing \
necessary from my end. Do not use variables.
If there are more than two function calls, separate them with a semicolon (;).
{tools}
Question: {input}
"""
class NexusBenchmark(BaseBenchmark):
    r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
    Function Calling Benchmark`
    <https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """
    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the Nexus Function Calling benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("nexus", data_dir, save_to, processes)
        # Unlike the base class, this benchmark holds one flat sample list.
        self._data: List[NexusSample] = []  # type: ignore[assignment]
    def download(self):
        r"""Download the Nexus Functional Calling Benchmark dataset."""
        from huggingface_hub import snapshot_download

        for dataset_name, repo_id in dataset_mapping.items():
            local_dir = self.data_dir / dataset_name
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                local_dir_use_symlinks=True,
            )
    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the Nexus Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool): Whether to force download the data.

        Returns:
            NexusBenchmark: The benchmark instance (base-class contract).
        """
        def _load_csv_data(dataset_dir: Path) -> List:
            r"""Load datasets from CSV files."""
            dataset = []
            for file_name in os.listdir(dataset_dir):
                file_path = dataset_dir / file_name
                if file_name.endswith(".csv"):
                    data = pd.read_csv(file_path)
                    for _, sample in data.iterrows():
                        dataset.append(
                            NexusSample(
                                sample["Input"], "".join(sample["Output"])
                            )
                        )
                    continue
                logger.warning(f"Skipping unsupported file: {file_name}")
            return dataset
        def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
            r"""Load datasets from Parquet files."""
            dataset = []
            if not data_dir.exists():
                raise FileNotFoundError(
                    f"Data directory '{data_dir}' does not exist."
                )
            for file_name in os.listdir(data_dir):
                file_path = data_dir / file_name
                if file_name.endswith(".parquet"):
                    data = pd.read_parquet(file_path)
                    dataset.extend(_process_parquet_data(data, dataset_name))
                    continue
                logger.warning(f"Skipping unsupported file: {file_name}")
            return dataset
        def _process_parquet_data(
            data: pd.DataFrame, dataset_name: str
        ) -> List:
            r"""Process data from Parquet files based on dataset name."""
            dataset: List = []
            dataset_handlers = {
                "NVDLibrary": _process_nvdlibrary,
                "VirusTotal": _process_simple,
                "PlacesAPI": _process_simple,
                "ClimateAPI": _process_simple,
                "OTX": _process_simple,
                "VirusTotal-NestedCalls": _process_nested_calls,
                "VirusTotal-ParallelCalls": _process_parallel_calls,
            }
            if dataset_name not in dataset_handlers:
                logger.warning(
                    f"No specific handler for dataset: {dataset_name}"
                )
                return dataset
            handler = dataset_handlers[dataset_name]
            for _, sample in data.iterrows():
                processed_sample = handler(sample)
                if processed_sample:
                    dataset.append(processed_sample)
            return dataset
        def _process_nvdlibrary(sample) -> NexusSample:
            r"""Process samples for the NVDLibrary dataset."""
            return NexusSample(
                sample["Input"], sample["Output"].replace("r = nvdlib.", "")
            )
        def _process_simple(sample) -> NexusSample:
            r"""Process samples for simple datasets (e.g., VirusTotal)."""
            return NexusSample(sample["Input"], sample["Output"])
        def _process_nested_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-NestedCalls dataset."""
            if len(sample["fncall"]) == 1:
                return NexusSample(
                    sample["generated_question"], "".join(sample["fncall"])
                )
            return None
        def _process_parallel_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-ParallelCalls dataset."""
            if len(sample["fncall"]) > 1:
                return NexusSample(
                    sample["generated_question"], "; ".join(sample["fncall"])
                )
            return None
        if force_download:
            logger.info("Force downloading data.")
            self.download()
        # Validate dataset name
        if dataset_name not in dataset_mapping:
            available_datasets = list(dataset_mapping.keys())
            raise ValueError(
                f"Dataset '{dataset_name}' is not recognized. "
                f"Available datasets: {available_datasets}"
            )
        # Get the dataset directory
        dataset_dir = self.data_dir / dataset_name
        if not dataset_dir.exists():
            raise FileNotFoundError(
                f"The dataset directory for '{dataset_name}' \
                does not exist at {dataset_dir}. "
                "Please download it first."
            )
        # Load the dataset
        if dataset_name == "NVDLibrary-NestedCalls":
            self._data = _load_csv_data(dataset_dir)
        else:
            self._data = _load_parquet_data(dataset_dir / "data", dataset_name)
        # Return self to honor the BaseBenchmark.load contract.
        return self
    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError(
            "Nexus Functional Calling has only a single 'train' set."
        )
    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        task: Literal[
            "NVDLibrary",
            "VirusTotal",
            "OTX",
            "PlacesAPI",
            "ClimateAPI",
            "VirusTotal-ParallelCalls",
            "VirusTotal-NestedCalls",
            "NVDLibrary-NestedCalls",
        ],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            task (Literal["NVDLibrary", "VirusTotal", "OTX",
                "PlacesAPI", "ClimateAPI", "VirusTotal-ParallelCalls",
                "VirusTotal-NestedCalls",
                "NVDLibrary-NestedCalls"]): The task to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        if task not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {task}.")
        logger.info(f"Running Nexus Function Calling benchmark on {task}.")
        self.load(task)
        datas = self._data
        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]
        logger.info(f"Number of tasks: {len(datas)}")
        # Initialize results storage
        self._results = []
        # Process samples
        tools = construct_tool_descriptions(task)
        with open(self.save_to, "w") as f:
            for sample in tqdm(datas, desc="Running"):
                prompt = construct_prompt(input=sample.input, tools=tools)
                ground_truth_call = sample.output
                try:
                    # Generate response
                    response = agent.step(prompt)
                    agent_call = response.msgs[0].content
                    # Evaluate response
                    if agent_call:
                        result = compare_function_calls(
                            agent_call=agent_call,
                            ground_truth_call=ground_truth_call,
                        )
                        self._results.append(
                            {
                                "input": sample.input,
                                "agent_call": agent_call,
                                "ground_truth_call": ground_truth_call,
                                "result": result,
                                "error": None,
                            }
                        )
                    else:
                        # Record empty responses too: the write below
                        # indexes self._results[-1], which previously
                        # raised IndexError (first sample) or re-wrote
                        # the previous record.
                        self._results.append(
                            {
                                "input": sample.input,
                                "agent_call": agent_call,
                                "ground_truth_call": ground_truth_call,
                                "result": 0,
                                "error": None,
                            }
                        )
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {sample.input}: {e}"
                    )
                    self._results.append(
                        {
                            "input": sample.input,
                            "agent_call": None,
                            "ground_truth_call": ground_truth_call,
                            "result": 0,
                            "error": str(e),
                        }
                    )
                agent.reset()
                json_str = json.dumps(
                    self._results[-1], indent=2, ensure_ascii=False
                )
                f.write(json_str + "\n")
                f.flush()
        total = len(self._results)
        correct = sum(r["result"] for r in self._results)
        return {
            "total": total,
            "correct": correct,
            # Guard against empty runs, matching the APIBench summary.
            "accuracy": correct / total if total else "N/A",
        }
# Utility functions
def construct_tool_descriptions(dataset_name: str) -> str:
    r"""Construct tool descriptions from function definitions and
    descriptions.

    Args:
        dataset_name (str): Benchmark task name (a ``dataset_mapping`` key).

    Returns:
        str: Prompt-ready text with one ``def``-style stub per tool.
    """
    tool_dataset_mapping = {
        "NVDLibrary": "CVECPE",
        "VirusTotal": "VirusTotal",
        "PlacesAPI": "Places",
        "ClimateAPI": "Climate",
        "OTX": "OTX",
        "VirusTotal-NestedCalls": "VT_Multi (Nested)",
        "VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
        "NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
    }
    if dataset_name not in tool_dataset_mapping:
        raise ValueError(
            f"Dataset '{dataset_name}' is not recognized. "
            f"Available datasets: {list(dataset_mapping.keys())}"
        )
    # Fetch the tool definitions for the requested task.
    dataset = load_dataset(
        "Nexusflow/Function_Call_Definitions",
        name=tool_dataset_mapping[dataset_name],
    )["train"]
    # Render each tool as a function stub with its description as docstring.
    sections = []
    for entry in dataset:
        tool = NexusTool(entry["function_calls"], entry["descriptions"])
        sections.append(
            f"Function:\ndef {tool.function_calls}:\n"
            + "\"\"\"\n"
            + f"{tool.descriptions}\n"
            + "\"\"\"\n"
        )
    return "".join(sections)
def construct_prompt(input: str, tools: str) -> str:
    r"""Fill the tool-calling prompt template with ``tools`` and ``input``."""
    prompt = TOOL_CALLING_PROMPT.format(tools=tools, input=input)
    return prompt
# Functions for function call evaluation
def parse_function_call(
    call: str,
) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
    r"""Parse a function call string to extract the function name,
    positional arguments, and keyword arguments, including
    nested function calls.

    Args:
        call (str): A string in the format `func(arg1, arg2, kwarg=value)`.

    Returns:
        tuple: (function_name (str), positional_args (list),
            keyword_args (dict)) or (None, None, None).
    """

    def _strip_formatting(raw: str) -> str:
        r"""Remove code fences and surrounding whitespace."""
        if raw.strip().startswith("```python"):
            raw = raw.strip().removeprefix("```python").removesuffix("```")
        return textwrap.dedent(raw).strip()

    def _resolve_arg(arg):
        r"""Recursively evaluate an AST argument node."""
        if isinstance(arg, ast.Call):
            # Nested call: parse it into its own (name, args, kwargs).
            return parse_function_call(ast.unparse(arg))
        if isinstance(arg, ast.Constant):
            # Literals such as numbers, strings, booleans, None.
            return arg.value
        if isinstance(arg, ast.List):
            return [_resolve_arg(el) for el in arg.elts]
        if isinstance(arg, ast.Dict):
            return {
                _resolve_arg(k): _resolve_arg(v)
                for k, v in zip(arg.keys, arg.values)
            }
        if isinstance(arg, ast.Tuple):
            return tuple(_resolve_arg(el) for el in arg.elts)
        # Safely evaluate any other literal expression.
        return ast.literal_eval(arg)

    cleaned = _strip_formatting(call)
    try:
        # Semicolon-separated calls: only the first one is parsed and
        # returned (matching the original behavior).
        for segment in cleaned.split(";"):
            tree = ast.parse(segment, mode='eval')
            if not isinstance(tree.body, ast.Call):
                raise ValueError(f"Not a valid function call: {call}")
            func = tree.body.func
            if isinstance(func, ast.Name):  # Simple function call
                func_name = func.id
            elif isinstance(func, ast.Attribute):  # Attribute function call
                func_name = f"{func.value.id}.{func.attr}"  # type: ignore[attr-defined]
            else:
                raise ValueError(f"Unsupported function call: {call}")
            args = [_resolve_arg(a) for a in tree.body.args]
            kwargs: Dict[str, Any] = {
                kw.arg: _resolve_arg(kw.value)
                for kw in tree.body.keywords
                if kw.arg is not None
            }
            logger.info("Valid call.")
            return func_name, args, kwargs
    except Exception as e:
        logger.info(f"Error parsing call: {call}, {e}")
        return None, None, None
def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
    r"""Compare the function name and arguments of
    agent_call and ground_truth_call.

    Args:
        agent_call (str): Function call by agent.
        ground_truth_call (str): Ground truth function call.

    Returns:
        bool:
            - `True` if both calls parse successfully and their function
              names and arguments match.
            - `False` otherwise, including when either call is unparseable.
    """
    # Parse both calls into (name, args, kwargs) triples.
    agent_parsed = parse_function_call(agent_call)
    gt_parsed = parse_function_call(ground_truth_call)
    # Fix: parse_function_call returns (None, None, None) on failure, and a
    # non-empty tuple is always truthy, so the previous
    # `if agent_parsed and gt_parsed` never filtered failures — two
    # unparseable calls wrongly compared as equal. Check the sentinel
    # explicitly instead.
    if agent_parsed[0] is None or gt_parsed[0] is None:
        return False
    return agent_parsed == gt_parsed

View File

@@ -0,0 +1,333 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
import numpy as np
from datasets import Dataset, load_dataset
from camel.agents import ChatAgent
from camel.benchmarks import BaseBenchmark
from camel.logger import get_logger
from camel.retrievers import AutoRetriever
logger = get_logger(__name__)
class RagasFields:
    r"""Constants for RAGAS evaluation field names.

    These are the column names the RAGAS library expects in an evaluation
    dataset; caller-provided columns are renamed to these before scoring.
    """

    # Column holding the retrieved context passages for each example.
    INPUT_CONTEXT = "contexts"
    # Column holding the original question/query.
    INPUT_QUESTION = "question"
    # Column holding the generated answer to be evaluated.
    INPUT_ANSWER = "answer"
def annotate_dataset(
    dataset: Dataset,
    context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
    answer_call: Optional[Callable[[Dict[str, Any]], str]],
) -> Dataset:
    r"""Annotate the dataset by adding context and answers using the provided
    functions.

    Args:
        dataset (Dataset): The input dataset to annotate.
        context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
            Function to generate context for each example.
        answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function to
            generate answer for each example.

    Returns:
        Dataset: The annotated dataset with added contexts and/or answers.
    """

    def _annotate_row(row: Dict[str, Any]) -> Dict[str, Any]:
        # Each annotator is optional; apply only those that were provided.
        if context_call:
            row["contexts"] = context_call(row)
        if answer_call:
            row["answer"] = answer_call(row)
        return row

    return dataset.map(_annotate_row)
def rmse(
    input_trues: Sequence[float],
    input_preds: Sequence[float],
) -> Optional[float]:
    r"""Calculate Root Mean Squared Error (RMSE).

    Args:
        input_trues (Sequence[float]): Ground truth values.
        input_preds (Sequence[float]): Predicted values.

    Returns:
        Optional[float]: RMSE value, or None if the inputs have different
            lengths or contain no valid (non-NaN) predictions.
    """
    if len(input_trues) != len(input_preds):
        logger.warning("Input lengths mismatch in RMSE calculation")
        return None
    trues_arr = np.array(input_trues)
    preds_arr = np.array(input_preds, dtype=float)
    # Score only the positions where the prediction is a real number.
    valid_mask = ~np.isnan(preds_arr)
    if not valid_mask.any():
        logger.warning("No valid predictions for RMSE calculation")
        return None
    diff = preds_arr[valid_mask] - trues_arr[valid_mask]
    return float(np.sqrt(np.mean(diff**2)))
def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
    r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).

    Args:
        trues (Sequence[bool]): Ground truth binary values.
        preds (Sequence[float]): Predicted probability values.

    Returns:
        float: AUROC score; 0.5 (the random-classifier baseline) when no
            valid (non-NaN) predictions are available.
    """
    from sklearn.metrics import roc_auc_score  # type: ignore[import-untyped]

    valid_mask = ~np.isnan(preds)
    if not valid_mask.any():
        logger.warning("No valid predictions for AUROC calculation")
        # Return random classifier score
        return 0.5
    y_true = np.array(trues)[valid_mask]
    y_score = np.array(preds)[valid_mask]
    return float(roc_auc_score(y_true, y_score))
def ragas_calculate_metrics(
    dataset: Dataset,
    pred_context_relevance_field: Optional[str],
    pred_faithfulness_field: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
    ground_truth_context_relevance_field: str = "relevance_score",
    ground_truth_faithfulness_field: str = "adherence_score",
) -> Dict[str, Optional[float]]:
    r"""Calculate RAGAS evaluation metrics.

    Args:
        dataset (Dataset): The dataset containing predictions and ground truth.
        pred_context_relevance_field (Optional[str]): Field name for predicted
            context relevance.
        pred_faithfulness_field (Optional[str]): Field name for predicted
            faithfulness.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
        ground_truth_context_relevance_field (str): Field name for ground truth
            relevance.
        ground_truth_faithfulness_field (str): Field name for ground truth
            adherence.

    Returns:
        Dict[str, Optional[float]]: Dictionary of calculated metrics.
    """
    if not metrics_to_evaluate:
        metrics_to_evaluate = ["context_relevancy", "faithfulness"]

    results: Dict[str, Optional[float]] = {}

    # Context relevance: RMSE between annotated and predicted scores.
    if (
        "context_relevancy" in metrics_to_evaluate
        and pred_context_relevance_field
    ):
        results["relevance_rmse"] = rmse(
            dataset[ground_truth_context_relevance_field],
            dataset[pred_context_relevance_field],
        )

    # Faithfulness: AUROC for hallucination detection. Hallucination is the
    # complement of adherence/faithfulness, hence both inversions below.
    if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
        gt_hallucination = ~np.array(
            dataset[ground_truth_faithfulness_field]
        )
        pred_hallucination = 1 - np.array(
            dataset[pred_faithfulness_field], dtype=float
        )
        results["hallucination_auroc"] = auroc(
            gt_hallucination.tolist(), pred_hallucination.tolist()
        )

    return results
def ragas_evaluate_dataset(
    dataset: Dataset,
    contexts_field_name: Optional[str],
    answer_field_name: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
) -> Dataset:
    r"""Evaluate the dataset using RAGAS metrics.

    Args:
        dataset (Dataset): Input dataset to evaluate.
        contexts_field_name (Optional[str]): Field name containing contexts.
        answer_field_name (Optional[str]): Field name containing answers.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.

    Returns:
        Dataset: Dataset with added evaluation metrics.
    """
    from ragas import evaluate  # type: ignore[import]
    from ragas.metrics import (  # type: ignore[import]
        context_relevancy,
        faithfulness,
    )

    if not metrics_to_evaluate:
        metrics_to_evaluate = ["context_relevancy", "faithfulness"]

    # RAGAS expects fixed column names; rename the caller's columns to match.
    column_renames = [
        (contexts_field_name, RagasFields.INPUT_CONTEXT),
        (answer_field_name, RagasFields.INPUT_ANSWER),
    ]
    for source_column, target_column in column_renames:
        if source_column and source_column != target_column:
            dataset = dataset.rename_column(source_column, target_column)

    # Collect the requested metric callables.
    selected_metrics = []
    if "context_relevancy" in metrics_to_evaluate:
        selected_metrics.append(context_relevancy)
    if "faithfulness" in metrics_to_evaluate:
        selected_metrics.append(faithfulness)

    ragas_result = evaluate(dataset, metrics=selected_metrics)
    return Dataset.from_pandas(ragas_result.to_pandas())
class RAGBenchBenchmark(BaseBenchmark):
    r"""RAGBench Benchmark for evaluating RAG performance.

    This benchmark uses the rungalileo/ragbench dataset to evaluate
    retrieval-augmented generation (RAG) systems. It measures context
    relevancy and faithfulness metrics as described in
    https://arxiv.org/abs/2407.11005.

    Args:
        processes (int, optional): Number of processes for parallel
            processing. (default: :obj:`1`)
        subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
            (default: :obj:`"hotpotqa"`)
        split (str, optional): Dataset split to use (e.g., "test").
            (default: :obj:`"test"`)
    """

    def __init__(
        self,
        processes: int = 1,
        subset: Literal[
            "covidqa",
            "cuad",
            "delucionqa",
            "emanual",
            "expertqa",
            "finqa",
            "hagrid",
            "hotpotqa",
            "msmarco",
            "pubmedqa",
            "tatqa",
            "techqa",
        ] = "hotpotqa",
        split: Literal["train", "test", "validation"] = "test",
    ) -> None:
        super().__init__("ragbench", "rag_bench", "", processes)
        self.subset = subset
        self.split = split
        # Populated lazily by `load()`/`download()`; None until then.
        self.dataset: Optional[Dataset] = None

    def download(self):
        r"""Download the RAGBench dataset from the Hugging Face Hub.

        Raises:
            Exception: Re-raises any error from `load_dataset` after
                logging it.
        """
        try:
            self.dataset = load_dataset(
                "rungalileo/ragbench", self.subset, split=self.split
            )
        except Exception as e:
            logger.error(f"Failed to download dataset: {e}")
            raise

    def load(self, force_download: bool = False):
        r"""Load the RAGBench dataset.

        Args:
            force_download (bool, optional): Whether to force download the
                data. (default: :obj:`False`)
        """
        if force_download or self.dataset is None:
            logger.info(
                "%s dataset",
                "Force downloading" if force_download else "Loading",
            )
            self.download()

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        auto_retriever: AutoRetriever,
    ) -> Dict[str, Optional[float]]:
        r"""Run the benchmark evaluation.

        Args:
            agent (ChatAgent): Chat agent for generating answers.
            auto_retriever (AutoRetriever): Retriever for finding relevant
                contexts.

        Returns:
            Dict[str, Optional[float]]: Dictionary of evaluation metrics.
        """
        # Fix: previously, calling run() without a prior load() passed
        # `None` into annotate_dataset and crashed. Load lazily instead.
        if self.dataset is None:
            self.load()

        def context_call(example):
            # Retrieve the single most similar passage for the question.
            retrieved_info = auto_retriever.run_vector_retriever(
                query=example['question'],
                contents=example['documents'],
                top_k=1,
                return_detailed_info=True,
                similarity_threshold=0.5,
            )
            return [c['text'] for c in retrieved_info['Retrieved Context']]

        def answer_call(example: Dict[str, Any]) -> str:
            # NOTE(review): the whole example dict is stringified as the
            # user message; confirm this is the intended prompt format.
            user_msg = str(example)
            assistant_response = agent.step(user_msg)
            return assistant_response.msg.content

        # Annotate the dataset with retrieved contexts and generated answers.
        annotated_ds = annotate_dataset(
            self.dataset, context_call, answer_call
        )
        # Score the annotated dataset with RAGAS, then reduce to metrics.
        evaluated_ds = ragas_evaluate_dataset(
            annotated_ds,
            contexts_field_name="contexts",
            answer_field_name="answer",
            metrics_to_evaluate=["context_relevancy", "faithfulness"],
        )
        return ragas_calculate_metrics(
            evaluated_ds,
            pred_context_relevance_field="context_relevancy",
            pred_faithfulness_field="faithfulness",
        )

34
camel/bots/__init__.py Normal file
View File

@@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord import DiscordApp
from .slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack.slack_app import SlackApp
from .telegram_bot import TelegramBot
# Public names re-exported by the camel.bots package.
__all__ = [
    'DiscordApp',
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
    'TelegramBot',
]

View File

@@ -0,0 +1,26 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord_app import DiscordApp
from .discord_installation import DiscordInstallation
from .discord_store import (
DiscordBaseInstallationStore,
DiscordSQLiteInstallationStore,
)
# Public names re-exported by the camel.bots.discord package.
__all__ = [
    "DiscordApp",
    "DiscordInstallation",
    "DiscordSQLiteInstallationStore",
    "DiscordBaseInstallationStore",
]

View File

@@ -0,0 +1,384 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, List, Optional
import discord
import httpx
from fastapi import FastAPI
from camel.bots.discord.discord_installation import DiscordInstallation
from camel.logger import get_logger
from camel.utils import api_keys_required, dependencies_required
from .discord_store import DiscordBaseInstallationStore
if TYPE_CHECKING:
from discord import Message
logger = get_logger(__name__)
TOKEN_URL = "https://discord.com/api/oauth2/token"
USER_URL = "https://discord.com/api/users/@me"
class DiscordApp:
    r"""A class representing a Discord app that uses the `discord.py` library
    to interact with Discord servers.

    This bot can respond to messages in specific channels and only reacts to
    messages that mention the bot.

    Attributes:
        channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
            provided, the bot will only respond to messages in these channels.
        token (Optional[str]): The Discord bot token used for authentication.
    """

    @dependencies_required('discord')
    @api_keys_required(
        [
            ("token", "DISCORD_BOT_TOKEN"),
        ]
    )
    def __init__(
        self,
        channel_ids: Optional[List[int]] = None,
        token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri: Optional[str] = None,
        installation_store: Optional[DiscordBaseInstallationStore] = None,
        intents: Optional[discord.Intents] = None,
    ) -> None:
        r"""Initialize the DiscordApp instance by setting up the Discord client
        and event handlers.

        Args:
            channel_ids (Optional[List[int]]): A list of allowed channel IDs.
                The bot will only respond to messages in these channels if
                provided. (default: :obj:`None`)
            token (Optional[str]): The Discord bot token for authentication.
                If not provided, the token will be retrieved from the
                environment variable `DISCORD_BOT_TOKEN`.
                (default: :obj:`None`)
            client_id (str, optional): The client ID for Discord OAuth.
                (default: :obj:`None`)
            client_secret (Optional[str]): The client secret for Discord OAuth.
                (default: :obj:`None`)
            redirect_uri (str): The redirect URI for OAuth callbacks.
                (default: :obj:`None`)
            installation_store (DiscordBaseInstallationStore): The database
                stores all information of all installations.
                (default: :obj:`None`)
            intents (discord.Intents): The Discord intents of this app.
                (default: :obj:`None`)

        Raises:
            ValueError: If the `DISCORD_BOT_TOKEN` is not found in environment
                variables.
        """
        self.token = token or os.getenv("DISCORD_BOT_TOKEN")
        self.channel_ids = channel_ids
        self.installation_store = installation_store

        # Default to all intents; message content and guilds are required
        # for the mention-based message handling below.
        if not intents:
            intents = discord.Intents.all()
            intents.message_content = True
            intents.guilds = True
        self._client = discord.Client(intents=intents)

        # Register event handlers
        self._client.event(self.on_ready)
        self._client.event(self.on_message)

        # OAuth flow: enabled only when every required piece is present.
        self.client_id = client_id or os.getenv("DISCORD_CLIENT_ID")
        self.client_secret = client_secret or os.getenv(
            "DISCORD_CLIENT_SECRET"
        )
        self.redirect_uri = redirect_uri
        self.oauth_flow = bool(
            self.client_id
            and self.client_secret
            and self.redirect_uri
            and self.installation_store
        )
        self.app = FastAPI()

    async def start(self):
        r"""Asynchronously start the Discord bot using its token.

        This method starts the bot and logs into Discord asynchronously using
        the provided token. It should be awaited when used in an async
        environment.
        """
        await self._client.start(self.token)

    def run(self) -> None:
        r"""Start the Discord bot using its token.

        This method starts the bot and logs into Discord synchronously using
        the provided token. It blocks execution and keeps the bot running.
        """
        self._client.run(self.token)  # type: ignore[arg-type]

    async def exchange_code_for_token_response(
        self, code: str
    ) -> Optional[dict]:
        r"""Exchange the authorization code for an access token response.

        Args:
            code (str): The authorization code received from Discord after
                user authorization.

        Returns:
            Optional[dict]: The full token response data (including access
                token, refresh token, and expiry) if successful, otherwise
                None.

        Raises:
            ValueError: If OAuth configuration is incomplete or invalid.
            httpx.RequestError: If there is a network issue during the request.
        """
        # Fix: return annotation/doc previously claimed Optional[str], but
        # the method returns the full JSON response dict.
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    TOKEN_URL, data=data, headers=headers
                )
                if response.status_code != 200:
                    logger.error(f"Failed to exchange code: {response.text}")
                    return None
                response_data = response.json()
                return response_data
        except (httpx.RequestError, ValueError) as e:
            logger.error(f"Error during token fetch: {e}")
            return None

    async def get_user_info(self, access_token: str) -> Optional[dict]:
        r"""Retrieve user information using the access token.

        Args:
            access_token (str): The access token received from Discord.

        Returns:
            dict: The user information retrieved from Discord.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            user_response = await client.get(USER_URL, headers=headers)
            return user_response.json()

    async def refresh_access_token(self, refresh_token: str) -> Optional[str]:
        r"""Refresh the access token using a refresh token.

        Args:
            refresh_token (str): The refresh token issued by Discord that
                can be used to obtain a new access token.

        Returns:
            Optional[str]: The new access token if successful, otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        async with httpx.AsyncClient() as client:
            response = await client.post(TOKEN_URL, data=data, headers=headers)
            if response.status_code != 200:
                logger.error(f"Failed to refresh token: {response.text}")
                return None
            response_data = response.json()
            return response_data.get("access_token")

    async def get_valid_access_token(self, guild_id: str) -> Optional[str]:
        r"""Retrieve a valid access token for the specified guild.

        This method attempts to retrieve an access token for a specific guild.
        If the current access token is expired, it will refresh the token using
        the refresh token.

        Args:
            guild_id (str): The ID of the guild to retrieve the access
                token for.

        Returns:
            Optional[str]: The valid access token if successful,
                otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        # oauth_flow guarantees installation_store is set.
        assert self.installation_store is not None
        installation = await self.installation_store.find_by_guild(
            guild_id=guild_id
        )
        if not installation:
            logger.error(f"No installation found for guild: {guild_id}")
            return None
        if (
            installation.token_expires_at
            and datetime.now() >= installation.token_expires_at
        ):
            logger.info(
                f"Access token expired for guild: {guild_id}, "
                f"refreshing token..."
            )
            new_access_token = await self.refresh_access_token(
                installation.refresh_token
            )
            if new_access_token:
                # Persist the refreshed token; Discord access tokens are
                # assumed to be valid for one hour here.
                installation.access_token = new_access_token
                installation.token_expires_at = datetime.now() + timedelta(
                    seconds=3600
                )
                await self.installation_store.save(installation)
                return new_access_token
            else:
                logger.error(
                    f"Failed to refresh access token for guild: {guild_id}"
                )
                return None
        return installation.access_token

    async def save_installation(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        expires_in: int,
    ):
        r"""Save the installation information for a given guild.

        Args:
            guild_id (str): The ID of the guild where the bot is installed.
            access_token (str): The access token for the guild.
            refresh_token (str): The refresh token for the guild.
            expires_in (int): The lifetime of the access token in seconds.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        expires_at = datetime.now() + timedelta(seconds=expires_in)
        installation = DiscordInstallation(
            guild_id=guild_id,
            access_token=access_token,
            refresh_token=refresh_token,
            installed_at=datetime.now(),
            token_expires_at=expires_at,
        )
        await self.installation_store.save(installation)
        logger.info(f"Installation saved for guild: {guild_id}")

    async def remove_installation(self, guild: discord.Guild):
        r"""Remove the installation for a given guild.

        Args:
            guild (discord.Guild): The guild from which the bot is
                being removed.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        await self.installation_store.delete(guild_id=str(guild.id))
        # Fix: use the module logger instead of print() for consistency
        # with every other code path in this class.
        logger.info(f"Bot removed from guild: {guild.id}")

    async def on_ready(self) -> None:
        r"""Event handler that is called when the bot has successfully
        connected to the Discord server.

        When the bot is ready and logged into Discord, it logs a message
        displaying the bot's username.
        """
        logger.info(f'We have logged in as {self._client.user}')

    async def on_message(self, message: 'Message') -> None:
        r"""Event handler for processing incoming messages.

        This method is called whenever a new message is received by the bot. It
        will ignore messages sent by the bot itself, only respond to messages
        in allowed channels (if specified), and only to messages that mention
        the bot.

        Args:
            message (discord.Message): The message object received from
                Discord.
        """
        # If the message author is the bot itself,
        # do not respond to this message
        if message.author == self._client.user:
            return
        # If allowed channel IDs are provided,
        # only respond to messages in those channels
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        # Only respond to messages that mention the bot
        if not self._client.user or not self._client.user.mentioned_in(
            message
        ):
            return
        logger.info(f"Received message: {message.content}")

    @property
    def client(self):
        r"""The underlying `discord.Client` instance."""
        return self._client

View File

@@ -0,0 +1,64 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from datetime import datetime
from typing import Optional
class DiscordInstallation:
    r"""Record of a Discord application installation in a single guild
    (server).

    Attributes:
        guild_id (str): Unique identifier of the Discord guild (server) in
            which the application is installed.
        access_token (str): Token used to authenticate API requests on
            behalf of this installation.
        refresh_token (str): Token used to obtain a new access token once
            the current one expires.
        installed_at (datetime): When the application was installed in the
            guild.
        token_expires_at (Optional[datetime]): When the access token
            expires, or None if it has no expiration time.
    """

    def __init__(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        installed_at: datetime,
        token_expires_at: Optional[datetime] = None,
    ):
        r"""Initialize the DiscordInstallation.

        Args:
            guild_id (str): Unique identifier of the Discord guild (server)
                where the application is installed.
            access_token (str): Token used to authenticate API requests for
                the installed application.
            refresh_token (str): Token used to refresh the access token
                when it expires.
            installed_at (datetime): Timestamp of when the application was
                installed in the guild.
            token_expires_at (Optional[datetime]): Expiration timestamp of
                the access token, or None if the token never expires.
                (default: :obj:`None`)
        """
        # Plain attribute assignments; ordering is not significant.
        self.token_expires_at = token_expires_at
        self.installed_at = installed_at
        self.refresh_token = refresh_token
        self.access_token = access_token
        self.guild_id = guild_id

View File

@@ -0,0 +1,160 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from .discord_installation import DiscordInstallation
class DiscordBaseInstallationStore:
    r"""Abstract base class for managing Discord installations.

    This class defines the interface for database operations related to storing
    and retrieving Discord installation data. Subclasses must implement these
    methods to handle database-specific logic.

    NOTE: the default implementations are silent no-ops (``find_by_guild``
    returns None implicitly); subclasses are expected to override every
    method.
    """

    async def init(self):
        r"""Initializes the database connection or structure."""
        pass

    async def save(self, installation: DiscordInstallation):
        r"""Saves or updates a Discord installation record.

        Args:
            installation (DiscordInstallation): The record to persist.
        """
        pass

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID.

        Args:
            guild_id (str): The guild ID to look up.

        Returns:
            Optional[DiscordInstallation]: The matching record, or None.
        """
        pass

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID.

        Args:
            guild_id (str): The guild ID of the record to delete.
        """
        pass
class DiscordSQLiteInstallationStore(DiscordBaseInstallationStore):
    r"""SQLite-based implementation for managing Discord installations.

    This class provides methods for initializing the database, saving,
    retrieving, and deleting installation records using SQLite.

    Attributes:
        database (str): Path to the SQLite database file.
    """

    def __init__(self, database: str):
        r"""Initializes the SQLite installation store.

        Args:
            database (str): Path to the SQLite database file.
        """
        self.database = database

    async def init(self):
        r"""Initializes the database by creating the required table if it
        does not exist."""
        import aiosqlite

        # guild_id is UNIQUE so that save() can upsert via ON CONFLICT.
        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS discord_installations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    guild_id TEXT NOT NULL UNIQUE,
                    access_token TEXT NOT NULL,
                    refresh_token TEXT NOT NULL,
                    installed_at DATETIME NOT NULL,
                    token_expires_at DATETIME
                );
                """
            )
            await db.commit()

    async def save(self, installation: DiscordInstallation):
        r"""Saves a new installation record or updates an existing one.

        Args:
            installation (DiscordInstallation): The installation data to save.
        """
        import aiosqlite

        # Upsert keyed on guild_id.
        # NOTE(review): installed_at is not updated on conflict —
        # presumably it records the first install time; confirm intended.
        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                INSERT INTO discord_installations (
                    guild_id, access_token, refresh_token, 
                    installed_at, token_expires_at
                ) VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(guild_id) DO UPDATE SET
                    access_token = excluded.access_token,
                    refresh_token = excluded.refresh_token,
                    token_expires_at = excluded.token_expires_at;
                """,
                [
                    installation.guild_id,
                    installation.access_token,
                    installation.refresh_token,
                    installation.installed_at,
                    installation.token_expires_at,
                ],
            )
            await db.commit()

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID.

        Args:
            guild_id (str): The guild ID to search for.

        Returns:
            Optional[DiscordInstallation]: The installation record if found,
                otherwise None.
        """
        import aiosqlite

        # NOTE(review): without sqlite detect_types, DATETIME columns come
        # back as strings — callers comparing token_expires_at against
        # datetime.now() may need a conversion; verify.
        async with aiosqlite.connect(self.database) as db:
            async with db.execute(
                "SELECT guild_id, access_token, refresh_token, "
                "installed_at, token_expires_at FROM discord_installations "
                "WHERE guild_id = ?",
                [guild_id],
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return DiscordInstallation(
                        guild_id=row[0],
                        access_token=row[1],
                        refresh_token=row[2],
                        installed_at=row[3],
                        token_expires_at=row[4],
                    )
                return None

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID.

        Args:
            guild_id (str): The guild ID of the record to delete.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                "DELETE FROM discord_installations WHERE guild_id = ?",
                [guild_id],
            )
            await db.commit()

View File

@@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackAuthProfile,
SlackEventBody,
SlackEventProfile,
)
from .slack_app import SlackApp
# Public surface of the Slack bot subpackage: the app plus its event models.
__all__ = [
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
]

158
camel/bots/slack/models.py Normal file
View File

@@ -0,0 +1,158 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional
from pydantic import BaseModel
class SlackAuthProfile(BaseModel):
    r"""Represents the authorization profile within a Slack event.

    Events will contain a single, compact authorizations field that shows one
    installation of your app that the event is visible to.
    In other words, lists of authorizations will be truncated to one element.

    If there's more than one installing party that your app is keeping track
    of, it's best not to rely on the single party listed in authorizations to
    be any particular one.

    To get a full list of who can see events, call the apps.event.
    authorizations.list method after obtaining an app-level token. Read more on
    the changes here; they have taken effect for existing apps as of
    February 24, 2021.

    References:

    - https://api.slack.com/apis/events-api#authorizations
    - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    enterprise_id: Optional[str] = None
    """The ID of the enterprise associated with the authorization."""

    team_id: str
    """The ID of the team associated with the authorization."""

    user_id: str
    """The ID of the user associated with the authorization."""

    is_bot: bool
    """Whether the authorized user is a bot."""

    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""
class SlackEventProfile(BaseModel):
    r"""Represents the detailed profile of a Slack event, including user,
    message, and context data.
    """

    user: str
    """The ID of the user associated with the event."""

    type: str
    """The type of the event (e.g., 'message')."""

    ts: str
    """A timestamp representing when the event was triggered."""

    thread_ts: Optional[str] = None
    """The timestamp of the parent message in a thread."""

    client_msg_id: str
    """A unique ID generated by the client for the message (if available)."""

    text: str
    """The message content text."""

    team: str
    """The ID of the team that the event is associated with."""

    blocks: list
    """The list of message blocks, providing structured information."""

    channel: str
    """The ID of the Slack channel where the event happened."""

    event_ts: str
    """The event-specific timestamp when it occurred."""

    # Required-but-nullable: no default, so the key must appear in the payload
    # even if its value is None. SlackAppMentionEventProfile relaxes this.
    channel_type: Optional[str]
    """The type of Slack channel (e.g., 'channel', 'im')."""
class SlackEventBody(BaseModel):
    r"""Represents the entire body of a Slack event, including the event
    profile, authorization, and context.
    """

    token: str
    """The token to verify the source of the event."""

    team_id: str
    """The ID of the team where the event is happening."""

    # Required-but-nullable: no default, so the key must appear in the payload
    # even if its value is None. SlackAppMentionEventBody relaxes this.
    context_team_id: Optional[str]
    """The team ID for the shared channel context, if applicable."""

    context_enterprise_id: Optional[str] = None
    """The enterprise ID for the shared channel context, if applicable."""

    api_app_id: str
    """The unique identifier for the Slack app that received the event."""

    event: SlackEventProfile
    """A detailed profile of the event"""

    type: str
    """The overall type of event received (e.g., 'event_callback')."""

    event_id: str
    """A unique identifier assigned to this event by Slack."""

    event_time: int
    """The timestamp (in seconds) representing when the event was triggered."""

    authorizations: Optional[list[SlackAuthProfile]] = None
    """An optional list of authorizations that describe which installation can
    see the event."""

    is_ext_shared_channel: bool
    """Indicates if the event is part of a shared channel between different
    organizations."""

    event_context: str
    """A unique string representing the context of the event."""
class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Represents the detailed profile of a Slack event where the app was
    mentioned in a message.
    """

    # Overrides the required parent field: app_mention payloads omit this key.
    channel_type: Optional[str] = None
    """The type of Slack channel. It's None for app mentions."""
class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was mentioned
    in a message.
    """

    # Overrides the required parent field: app_mention payloads omit this key.
    context_team_id: Optional[str] = None
    """The team ID for the shared channel context. It's None for app
    mentions."""

    event: SlackAppMentionEventProfile
    """A detailed profile of the event"""

View File

@@ -0,0 +1,255 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Optional
from slack_sdk.oauth.installation_store.async_installation_store import (
AsyncInstallationStore,
)
from starlette import requests, responses
from camel.bots.slack.models import (
SlackAppMentionEventBody,
SlackAppMentionEventProfile,
SlackEventBody,
SlackEventProfile,
)
from camel.utils import dependencies_required
if TYPE_CHECKING:
from slack_bolt.context.async_context import AsyncBoltContext
from slack_bolt.context.say.async_say import AsyncSay
from slack_sdk.web.async_client import AsyncWebClient
# NOTE(review): basicConfig configures the *root* logger as an import-time
# side effect; libraries usually leave that to the application -- confirm.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults to
            "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The installation
            store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt app
        and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth redirects
                (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An optional
                installation store for OAuth installations.

        Raises:
            ValueError: If the token, scopes, or signing secret cannot be
                resolved from the arguments or the environment.
        """
        # slack_bolt imports are deferred to call time (its availability is
        # guarded by the @dependencies_required decorator above).
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        # Each credential falls back to its environment variable when the
        # corresponding argument is not supplied.
        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        # token/scopes/signing_secret are always required; client_id and
        # client_secret are only needed for the OAuth flow below.
        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )

        # Setup OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )

        # Starlette adapter translating HTTP requests into Bolt events.
        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as `app_mention`
        and `message`.

        This method registers the `app_mention` and `on_message` event handlers
        with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events. Blocks until the server is stopped.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events (default
                is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default is
                None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.
        Currently it only validates the payload into models and logs it.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        # Validate the raw payload; raises if Slack's schema changed.
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)
        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.
        Currently it acknowledges the event, validates the payload into
        models, and logs it.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        await context.ack()
        # Validate the raw payload; raises if Slack's schema changed.
        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)
        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")
        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False otherwise.
        """
        message = body.event.text
        bot_user_id = context.bot_user_id
        # Slack encodes mentions in message text as "<@USERID>".
        mention = f"<@{bot_user_id}>"
        return mention in message

View File

@@ -0,0 +1,78 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, Optional
from camel.agents import ChatAgent
from camel.utils import dependencies_required
# Conditionally import telebot types only for type checking
if TYPE_CHECKING:
from telebot.types import ( # type: ignore[import-untyped]
Message,
)
class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        r"""Initializes the bot and registers its message handler.

        Args:
            chat_agent (ChatAgent): The agent used to generate replies.
            telegram_token (Optional[str]): The bot token. Falls back to the
                `TELEGRAM_TOKEN` environment variable when not provided.

        Raises:
            ValueError: If no token is given and `TELEGRAM_TOKEN` is not set.
        """
        self.chat_agent = chat_agent
        # Prefer the explicit argument; otherwise fall back to the env var.
        self.token = telegram_token or os.getenv('TELEGRAM_TOKEN')
        if not self.token:
            raise ValueError(
                "`TELEGRAM_TOKEN` not found in environment variables. "
                "Get it from t.me/BotFather."
            )

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)
        # Register a catch-all message handler within the constructor so the
        # bot responds to every incoming message.
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot and poll for new messages until stopped."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        # Reset the agent so each message is answered without prior context.
        self.chat_agent.reset()
        # Ignore non-text updates (photos, stickers, etc.).
        if not message.text:
            return
        assistant_response = self.chat_agent.step(message.text)
        self.bot.reply_to(message, assistant_response.msg.content)

106
camel/configs/__init__.py Normal file
View File

@@ -0,0 +1,106 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .aiml_config import AIML_API_PARAMS, AIMLConfig
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .bedrock_config import BEDROCK_API_PARAMS, BedrockConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .internlm_config import INTERNLM_API_PARAMS, InternLMConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .lmstudio_config import LMSTUDIO_API_PARAMS, LMStudioConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .modelscope_config import MODELSCOPE_API_PARAMS, ModelScopeConfig
from .moonshot_config import MOONSHOT_API_PARAMS, MoonshotConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .openrouter_config import OPENROUTER_API_PARAMS, OpenRouterConfig
from .ppio_config import PPIO_API_PARAMS, PPIOConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .samba_config import (
SAMBA_CLOUD_API_PARAMS,
SAMBA_VERSE_API_PARAMS,
SambaCloudAPIConfig,
SambaVerseAPIConfig,
)
from .sglang_config import SGLANG_API_PARAMS, SGLangConfig
from .siliconflow_config import SILICONFLOW_API_PARAMS, SiliconFlowConfig
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
__all__ = [
'BaseConfig',
'ChatGPTConfig',
'OPENAI_API_PARAMS',
'AnthropicConfig',
'ANTHROPIC_API_PARAMS',
'GROQ_API_PARAMS',
'GroqConfig',
'LiteLLMConfig',
'LITELLM_API_PARAMS',
'NvidiaConfig',
'NVIDIA_API_PARAMS',
'OllamaConfig',
'OLLAMA_API_PARAMS',
'ZhipuAIConfig',
'ZHIPUAI_API_PARAMS',
'GeminiConfig',
'Gemini_API_PARAMS',
'VLLMConfig',
'VLLM_API_PARAMS',
'SGLangConfig',
'SGLANG_API_PARAMS',
'MistralConfig',
'MISTRAL_API_PARAMS',
'RekaConfig',
'REKA_API_PARAMS',
'SambaVerseAPIConfig',
'SAMBA_VERSE_API_PARAMS',
'SambaCloudAPIConfig',
'SAMBA_CLOUD_API_PARAMS',
'TogetherAIConfig',
'TOGETHERAI_API_PARAMS',
'CohereConfig',
'COHERE_API_PARAMS',
'YiConfig',
'YI_API_PARAMS',
'QwenConfig',
'QWEN_API_PARAMS',
'BedrockConfig',
'BEDROCK_API_PARAMS',
'DeepSeekConfig',
'DEEPSEEK_API_PARAMS',
'PPIOConfig',
'PPIO_API_PARAMS',
'InternLMConfig',
'INTERNLM_API_PARAMS',
'MoonshotConfig',
"MOONSHOT_API_PARAMS",
'ModelScopeConfig',
'MODELSCOPE_API_PARAMS',
'SiliconFlowConfig',
'SILICONFLOW_API_PARAMS',
'AIMLConfig',
'AIML_API_PARAMS',
'OpenRouterConfig',
'OPENROUTER_API_PARAMS',
'LMSTUDIO_API_PARAMS',
'LMStudioConfig',
]

View File

@@ -0,0 +1,81 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NotGiven
class AIMLConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    AIML API.

    Args:
        temperature (float, optional): Determines the degree of randomness
            in the response. (default: :obj:`None`)
        top_p (float, optional): The top_p (nucleus) parameter is used to
            dynamically adjust the number of choices for each predicted token
            based on the cumulative probabilities. (default: :obj:`None`)
        n (int, optional): Number of generations to return.
            (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. (default: :obj:`None`)
        stream (bool, optional): If set, tokens are returned as Server-Sent
            Events as they are made available. (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate.
            (default: :obj:`None`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between :obj:`-1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
    """

    # All parameters default to unset; BaseConfig.as_dict drops None values so
    # only explicitly configured options are sent to the API.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str], NotGiven]] = None
    max_tokens: Optional[Union[int, NotGiven]] = None
    logit_bias: dict = Field(default_factory=dict)
    response_format: Optional[Union[dict, NotGiven]] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
# Names of all AIML request parameters, derived from the config's fields.
AIML_API_PARAMS = set(AIMLConfig.model_fields.keys())

View File

@@ -0,0 +1,80 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional
from camel.configs.base_config import BaseConfig
class AnthropicConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Anthropic API.

    See: https://docs.anthropic.com/en/api/messages

    Args:
        max_tokens (int, optional): The maximum number of tokens to
            generate before stopping. Note that Anthropic models may stop
            before reaching this maximum. This parameter only specifies the
            absolute maximum number of tokens to generate.
            (default: :obj:`None`)
        stop_sequences (List[str], optional): Custom text sequences that will
            cause the model to stop generating. The models will normally stop
            when they have naturally completed their turn. If the model
            encounters one of these custom sequences, the response will be
            terminated and the stop_reason will be "stop_sequence".
            (default: :obj:`None`)
        temperature (float, optional): Amount of randomness injected into the
            response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
            for analytical / multiple choice, and closer to 1 for creative
            and generative tasks. Note that even with temperature of 0.0, the
            results will not be fully deterministic. (default: :obj:`None`)
        top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
            compute the cumulative distribution over all the options for each
            subsequent token in decreasing probability order and cut it off
            once it reaches a particular probability specified by `top_p`.
            You should either alter `temperature` or `top_p`,
            but not both. (default: :obj:`None`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses. (default: :obj:`None`)
        stream (bool, optional): Whether to incrementally stream the response
            using server-sent events. (default: :obj:`None`)
        metadata (dict, optional): An object describing
            metadata about the request. Can include user_id as an external
            identifier for the user associated with the request.
            (default: :obj:`None`)
        thinking (dict, optional): Configuration for enabling
            Claude's extended thinking. When enabled, responses include
            thinking content blocks showing Claude's thinking process.
            (default: :obj:`None`)
        tool_choice (dict, optional): How the model should
            use the provided tools. The model can use a specific tool, any
            available tool, decide by itself, or not use tools at all.
            (default: :obj:`None`)
    """

    # All parameters default to unset; BaseConfig.as_dict drops None values so
    # only explicitly configured options are sent to the API.
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stream: Optional[bool] = None
    metadata: Optional[dict] = None
    thinking: Optional[dict] = None
    tool_choice: Optional[dict] = None
# Names of all Anthropic request parameters, derived from the config's fields.
ANTHROPIC_API_PARAMS = set(AnthropicConfig.model_fields.keys())

View File

@@ -0,0 +1,86 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from abc import ABC
from typing import Any, List, Optional
from pydantic import BaseModel, ConfigDict, field_validator
class BaseConfig(ABC, BaseModel):
    r"""Base configuration class for all models.

    This class provides a common interface for all models, ensuring that all
    models have a consistent set of attributes and methods.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    tools: Optional[List[Any]] = None
    """A list of tools the model may
    call. Currently, only functions are supported as a tool. Use this
    to provide a list of functions the model may generate JSON inputs
    for. A max of 128 functions are supported.
    """

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate the type of tools in the configuration.

        This method ensures that the tools provided in the configuration are
        instances of `FunctionTool`. If any tool is not an instance of
        `FunctionTool`, it raises a ValueError.

        Raises:
            ValueError: If any element of ``tools`` is not a `FunctionTool`.
        """
        if tools is not None:
            # Local import; presumably deferred to avoid an import cycle with
            # camel.toolkits -- confirm before moving to module level.
            from camel.toolkits import FunctionTool

            for tool in tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.
        The dictionary won't contain None values, as some API does not support
        None values. (Like tool in OpenAI beta API)

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        # Convert tools to OpenAI tool schema
        config_dict["tools"] = (
            [tool.get_openai_tool_schema() for tool in self.tools]
            if self.tools
            else None
        )
        # Remove None values
        return {k: v for k, v in config_dict.items() if v is not None}

View File

@@ -0,0 +1,73 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, Optional, Union
from camel.configs.base_config import BaseConfig
class BedrockConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.

    Args:
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        top_k (int, optional): The number of top tokens to consider.
            (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
        reasoning_effort(str, optional): A parameter specifying the level of
            reasoning used by certain model types. Valid values are :obj:
            `"low"`, :obj:`"medium"`, or :obj:`"high"`. If set, it is only
            applied to the model types that support it (e.g., :obj:`o1`,
            :obj:`o1mini`, :obj:`o1preview`, :obj:`o3mini`). If not provided
            or if the model type does not support it, this parameter is
            ignored. (default: :obj:`None`)
    """

    # NOTE(review): the reasoning_effort description above references OpenAI
    # o1/o3 model types; confirm it applies to Bedrock-hosted models.
    # All parameters default to unset; BaseConfig.as_dict drops None values so
    # only explicitly configured options are sent to the API.
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stream: Optional[bool] = None
    tool_choice: Optional[Union[Dict[str, str], str]] = None
    reasoning_effort: Optional[str] = None
# Names of all Bedrock request parameters, derived from the config's fields.
BEDROCK_API_PARAMS = set(BedrockConfig.model_fields.keys())

View File

@@ -0,0 +1,77 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional
from camel.configs.base_config import BaseConfig
class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List[str], optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation.
            (default: :obj:`None`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`None`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`None`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`None`)
    """
    # Every field defaults to None, i.e. "not specified by the caller".
    temperature: Optional[float] = None
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    k: Optional[int] = None
    p: Optional[float] = None
# Public set of request-parameter names accepted by the Cohere backend.
# Access `model_fields` on the class, not on an instance: instantiating the
# config here is unnecessary, and instance access to `model_fields` is
# deprecated in pydantic v2. This also matches the sibling config modules.
COHERE_API_PARAMS = set(CohereConfig.model_fields.keys())

View File

@@ -0,0 +1,108 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`None`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`None`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`None`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`None`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`None`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`.
            (default: :obj:`None`)
    """
    temperature: Optional[float] = None  # deepseek default: 1.0
    top_p: Optional[float] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[Union[Type[BaseModel], dict]] = None
    frequency_penalty: Optional[float] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    def __init__(self, include_usage: bool = True, **kwargs) -> None:
        r"""Initialize the config.

        Args:
            include_usage (bool): When ``stream`` is enabled, whether the
                final streamed chunk should report token usage via the
                ``stream_options`` request field. (default: :obj:`True`)
            **kwargs: Field values forwarded to :class:`BaseConfig`.
        """
        super().__init__(**kwargs)
        # Only set stream_options when stream is True
        # Otherwise, it will raise error when calling the API
        if self.stream:
            self.stream_options = {"include_usage": include_usage}
# Public set of request-parameter names accepted by the DeepSeek backend.
# `set(...)` replaces the redundant set comprehension over `.keys()`.
DEEPSEEK_API_PARAMS = set(DeepSeekConfig.model_fields.keys())

View File

@@ -0,0 +1,88 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length. (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    # NOTE(review): API-side default is presumably 1.0 (comment inherited
    # from the OpenAI config) — confirm against the Gemini docs.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    response_format: Optional[Union[Type[BaseModel], dict]] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Public set of request-parameter names accepted by the Gemini backend,
# named in UPPER_SNAKE_CASE to match every sibling config module.
GEMINI_API_PARAMS = set(GeminiConfig.model_fields.keys())
# Backward-compatible alias preserving the original mixed-case name.
Gemini_API_PARAMS = GEMINI_API_PARAMS

View File

@@ -0,0 +1,103 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
class GroqConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.
    Reference: https://console.groq.com/docs/openai
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length. (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    # Every field defaults to None, i.e. "not specified by the caller".
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[dict] = None
    frequency_penalty: Optional[float] = None
    user: Optional[str] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Public set of request-parameter names accepted by the Groq backend.
# `set(...)` replaces the redundant set comprehension over `.keys()`.
GROQ_API_PARAMS = set(GroqConfig.model_fields.keys())

View File

@@ -0,0 +1,60 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
class InternLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    InternLM API. You can refer to the following link for more details:
    https://internlm.intern-ai.org.cn/api/document
    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`None`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`None`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`None`)
        max_tokens (int, optional): Allows the model to
            generate the maximum number of tokens.
            (default: :obj:`None`)
        tools (list, optional): Specifies an array of tools that the model can
            call. It can contain one or more tool objects. During a function
            call process, the model will select one tool from the array.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    # Every field defaults to None, i.e. "not specified by the caller".
    stream: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    # NOTE(review): this module (unlike its siblings) has no
    # `from __future__ import annotations`, so the `dict[str, str]` below is
    # evaluated eagerly and requires Python >= 3.9 — confirm minimum version.
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Public set of request-parameter names accepted by the InternLM backend.
# `set(...)` replaces the redundant set comprehension over `.keys()`.
INTERNLM_API_PARAMS = set(InternLMConfig.model_fields.keys())

View File

@@ -0,0 +1,99 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from camel.configs.base_config import BaseConfig
class LiteLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    LiteLLM API.
    Args:
        timeout (Optional[Union[float, str]], optional): Request timeout.
            (default: :obj:`None`)
        temperature (Optional[float], optional): Temperature parameter for
            controlling randomness. (default: :obj:`None`)
        top_p (Optional[float], optional): Top-p parameter for nucleus
            sampling. (default: :obj:`None`)
        n (Optional[int], optional): Number of completions to generate.
            (default: :obj:`None`)
        stream (Optional[bool], optional): Whether to return a streaming
            response. (default: :obj:`None`)
        stream_options (Optional[dict], optional): Options for the streaming
            response. (default: :obj:`None`)
        stop (Optional[Union[str, List[str]]], optional): Sequences where the
            API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (Optional[int], optional): Maximum number of tokens to
            generate. (default: :obj:`None`)
        presence_penalty (Optional[float], optional): Penalize new tokens
            based on their existence in the text so far. (default: :obj:`None`)
        frequency_penalty (Optional[float], optional): Penalize new tokens
            based on their frequency in the text so far. (default: :obj:`None`)
        logit_bias (Optional[dict], optional): Modify the probability of
            specific tokens appearing in the completion. (default: :obj:`None`)
        user (Optional[str], optional): A unique identifier representing the
            end-user. (default: :obj:`None`)
        response_format (Optional[dict], optional): Response format
            parameters. (default: :obj:`None`)
        seed (Optional[int], optional): Random seed. (default: :obj:`None`)
        tools (Optional[List], optional): List of tools. (default: :obj:`None`)
        tool_choice (Optional[Union[str, dict]], optional): Tool choice
            parameters. (default: :obj:`None`)
        logprobs (Optional[bool], optional): Whether to return log
            probabilities of the output tokens. (default: :obj:`None`)
        top_logprobs (Optional[int], optional): Number of most likely tokens
            to return at each token position. (default: :obj:`None`)
        deployment_id (Optional[str], optional): Deployment ID.
            (default: :obj:`None`)
        extra_headers (Optional[dict], optional): Additional headers for the
            request. (default: :obj:`None`)
        api_version (Optional[str], optional): API version.
            (default: :obj:`None`)
        mock_response (Optional[str], optional): Mock completion response for
            testing or debugging. (default: :obj:`None`)
        custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
            provider. (default: :obj:`None`)
        max_retries (Optional[int], optional): Maximum number of retries.
            (default: :obj:`None`)
    """
    # Every field defaults to None, i.e. "not specified by the caller".
    # NOTE(review): `tools` is documented above but not declared here —
    # presumably inherited from BaseConfig; confirm.
    timeout: Optional[Union[float, str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[dict] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    tool_choice: Optional[Union[str, dict]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    deployment_id: Optional[str] = None
    extra_headers: Optional[dict] = None
    api_version: Optional[str] = None
    mock_response: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    max_retries: Optional[int] = None
# Public set of request-parameter names accepted by the LiteLLM backend.
# `set(...)` replaces the redundant set comprehension over `.keys()`.
LITELLM_API_PARAMS = set(LiteLLMConfig.model_fields.keys())

View File

@@ -0,0 +1,94 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
class LMStudioConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using OpenAI
    compatibility.
    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length. (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """
    # Every field defaults to None, i.e. "not specified by the caller".
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[dict] = None
    frequency_penalty: Optional[float] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Public set of request-parameter names accepted by the LM Studio backend.
# `set(...)` replaces the redundant set comprehension over `.keys()`.
LMSTUDIO_API_PARAMS = set(LMStudioConfig.model_fields.keys())

View File

@@ -0,0 +1,79 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from pydantic import field_validator
from camel.configs.base_config import BaseConfig
class MistralConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Mistral API.
    reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195
    #TODO: Support stream mode
    Args:
        temperature (Optional[float], optional): the temperature
            to use for sampling, e.g. 0.5. (default: :obj:`None`)
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. (default: :obj:`None`)
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. (default: :obj:`None`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list. (default: :obj:`None`)
        random_seed (Optional[int], optional): the random seed to use for
            sampling, e.g. 42. (default: :obj:`None`)
        safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
            (default: :obj:`None`)
        response_format (Union[Dict[str, str], ResponseFormat], optional):
            format of the response. (default: :obj:`None`)
        tool_choice (str, optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"any"` means the
            model must call one or more tools. :obj:`"auto"` is the default
            value.
    """
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    random_seed: Optional[int] = None
    safe_prompt: Optional[bool] = None
    response_format: Optional[Union[Dict[str, str], Any]] = None
    tool_choice: Optional[str] = None
    @field_validator("response_format", mode="before")
    @classmethod
    def fields_type_checking(cls, response_format):
        r"""Validate that ``response_format`` is either a plain dict or a
        ``mistralai.models.ResponseFormat`` instance.

        The import is deferred so the `mistralai` package is only required
        when a non-dict response format is actually supplied.
        """
        if response_format and not isinstance(response_format, dict):
            from mistralai.models import ResponseFormat
            if not isinstance(response_format, ResponseFormat):
                # Previous message wrongly called the value a "tool" and
                # omitted that plain dicts are also accepted.
                raise ValueError(
                    f"The response_format {response_format} should be a dict "
                    "or an instance of `mistralai.models.ResponseFormat`."
                )
        return response_format
MISTRAL_API_PARAMS = {param for param in MistralConfig().model_fields.keys()}

View File

@@ -0,0 +1,59 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
class ModelScopeConfig(BaseConfig):
    r"""Options for chat completions served through the ModelScope API.

    Upstream reference:
    https://www.modelscope.cn/docs/model-service/API-Inference/intro

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the API applies its own default.

    Args:
        tool_choice (Union[dict[str, str], str], optional): Controls tool
            selection: :obj:`"none"` forbids tool calls and produces a plain
            message, :obj:`"auto"` lets the model choose between answering
            and calling tools, while :obj:`"required"` or a concrete
            {"type": "function", "function": {"name": "some_function"}}
            mapping pushes the model more strongly toward tool use.
            (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on the number of tokens the
            model may generate; the bound is not guaranteed to be reached.
            (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling control; lower values give
            less random output, higher values more random output.
            (default: :obj:`None`)
        temperature (float, optional): Diversity/focus control; lower values
            yield more focused output, higher values more diverse output.
            (default: :obj:`None`)
        stream (bool, optional): When True, output is streamed.
            (default: :obj:`None`)
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Generation limit and delivery mode.
    max_tokens: Optional[int] = None
    stream: Optional[bool] = None
    # Tool-use routing.
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Names of all request parameters accepted by the ModelScope backend.
MODELSCOPE_API_PARAMS = set(ModelScopeConfig.model_fields)

View File

@@ -0,0 +1,63 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Optional, Union
from camel.configs.base_config import BaseConfig
class MoonshotConfig(BaseConfig):
    r"""Options for chat completions served through the Moonshot API.

    Upstream reference: https://platform.moonshot.cn/docs/api-reference

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the API applies its own default.

    Args:
        temperature (float, optional): Randomness control; lower values make
            the output more focused and deterministic. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on the number of generated
            tokens. (default: :obj:`None`)
        stream (bool, optional): When True, the response is streamed.
            (default: :obj:`None`)
        tools (list, optional): Tools available for function calling; each
            entry is a dict carrying the type, function name, description,
            and parameters. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling control of diversity.
            (default: :obj:`None`)
        n (int, optional): Number of completion choices generated per input
            message. (default: :obj:`None`)
        presence_penalty (float, optional): Penalizes tokens that already
            appeared in the text so far. (default: :obj:`None`)
        frequency_penalty (float, optional): Penalizes tokens proportionally
            to their frequency in the text so far. (default: :obj:`None`)
        stop (Optional[Union[str, List[str]]], optional): Up to 4 sequences
            at which generation stops. (default: :obj:`None`)
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Generation limits and choice count.
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    # Repetition penalties.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    # Delivery mode and tool definitions.
    stream: Optional[bool] = None
    tools: Optional[list] = None
MOONSHOT_API_PARAMS = {param for param in MoonshotConfig.model_fields.keys()}

View File

@@ -0,0 +1,70 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import List, Optional, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NotGiven
class NvidiaConfig(BaseConfig):
    r"""Configuration class for NVIDIA API models.

    Holds the generation parameters NVIDIA's language models accept:
    temperature, sampling controls, penalties, and stop sequences. Every
    field defaults to :obj:`None`, i.e. the parameter is omitted from the
    request.

    Args:
        stream (bool, optional): Whether to stream the response.
            (default: :obj:`None`)
        temperature (float, optional): Randomness control; higher values
            make output more random, lower values more deterministic.
            Range: [0.0, 2.0]. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling diversity control.
            Range: [0.0, 1.0]. (default: :obj:`None`)
        presence_penalty (float, optional): Penalizes tokens that already
            appeared in the text. Range: [-2.0, 2.0]. (default: :obj:`None`)
        frequency_penalty (float, optional): Penalizes tokens by their
            frequency in the text. Range: [-2.0, 2.0]. (default: :obj:`None`)
        max_tokens (Union[int, NotGiven], optional): Maximum number of
            tokens to generate; the model's own maximum applies when unset.
            (default: :obj:`None`)
        seed (Optional[int], optional): Random seed for deterministic
            sampling. (default: :obj:`None`)
        tools (Optional[List[Dict]], optional): Tools available to the
            model, such as a text editor, calculator, or search engine.
            (default: :obj:`None`)
        tool_choice (Optional[str], optional): Tool choice configuration.
            (default: :obj:`None`)
        stop (Optional[List[str]], optional): List of stop sequences.
            (default: :obj:`None`)
    """

    # Plain ``None`` defaults are equivalent to ``Field(default=None)``.
    stream: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    max_tokens: Optional[Union[int, NotGiven]] = None
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    stop: Optional[List[str]] = None
NVIDIA_API_PARAMS = {param for param in NvidiaConfig.model_fields.keys()}

View File

@@ -0,0 +1,83 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
class OllamaConfig(BaseConfig):
    r"""Options for chat completions via Ollama's OpenAI-compatible API.

    Reference: https://github.com/ollama/ollama/blob/main/docs/openai.md

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the backend applies its own default.

    Args:
        temperature (float, optional): Sampling temperature between :obj:`0`
            and :obj:`2`; larger values yield more random output, smaller
            values more focused, deterministic output. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling alternative to
            temperature: only the tokens in the top ``top_p`` probability
            mass are considered, so :obj:`0.1` keeps just the top 10% of
            the distribution. (default: :obj:`None`)
        response_format (object, optional): Desired output format. Setting
            {"type": "json_object"} enables JSON mode, which guarantees
            syntactically valid JSON; you must still instruct the model to
            produce JSON via a system or user message, otherwise it may
            stream whitespace until hitting the token limit, leaving the
            request seemingly "stuck". Content may also be truncated when
            finish_reason="length". Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
            (default: :obj:`None`)
        stream (bool, optional): When True, partial message deltas are sent
            as data-only server-sent events while generation proceeds.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences at which the
            API stops generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on generated tokens; input
            plus output is capped by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens that already
            appeared, nudging the model toward new topics.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens proportionally to
            their existing frequency, reducing verbatim repetition.
            (default: :obj:`None`)
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Repetition penalties.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    # Generation limits and delivery mode.
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    stream: Optional[bool] = None
    # Output shaping.
    response_format: Optional[Union[Type[BaseModel], dict]] = None
OLLAMA_API_PARAMS = {param for param in OllamaConfig.model_fields.keys()}

View File

@@ -0,0 +1,125 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Dict, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
class ChatGPTConfig(BaseConfig):
    r"""Configuration options accepted by the OpenAI chat-completions API.

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the API applies its own default.

    Args:
        temperature (float, optional): Sampling temperature between :obj:`0`
            and :obj:`2`; larger values yield more random output, smaller
            values more focused, deterministic output. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling alternative to
            temperature: only the tokens in the top ``top_p`` probability
            mass are considered, so :obj:`0.1` keeps just the top 10% of
            the distribution. (default: :obj:`None`)
        n (int, optional): Number of completion choices generated per input
            message. (default: :obj:`None`)
        response_format (object, optional): Desired output format. Setting
            {"type": "json_object"} enables JSON mode, which guarantees
            syntactically valid JSON; you must still instruct the model to
            produce JSON via a system or user message, otherwise it may
            stream whitespace until hitting the token limit, leaving the
            request seemingly "stuck". Content may also be truncated when
            finish_reason="length". Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
        stream (bool, optional): When True, partial message deltas are sent
            as data-only server-sent events while generation proceeds.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences at which the
            API stops generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on generated tokens; input
            plus output is capped by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens that already
            appeared, nudging the model toward new topics.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens proportionally to
            their existing frequency, reducing verbatim repetition.
            (default: :obj:`None`)
        logit_bias (dict, optional): Maps tokenizer token IDs to bias values
            from :obj:`-100` to :obj:`100` that are added to the logits
            before sampling. Values between :obj:`-1` and :obj:`1` shift
            selection likelihood; :obj:`-100` or :obj:`100` effectively ban
            or force the token. (default: :obj:`None`)
        user (str, optional): End-user identifier that helps OpenAI monitor
            and detect abuse. (default: :obj:`None`)
        tools (list[FunctionTool], optional): Tools the model may call;
            currently only functions are supported, with a maximum of 128.
        tool_choice (Union[dict[str, str], str], optional): Tool selection
            control: :obj:`"none"` forbids tool calls and produces a plain
            message, :obj:`"auto"` lets the model choose, :obj:`"required"`
            forces at least one call, and a concrete
            {"type": "function", "function": {"name": "my_function"}}
            mapping forces that specific tool. :obj:`"none"` is the default
            without tools; :obj:`"auto"` is the default when tools are
            present.
        reasoning_effort (str, optional): Reasoning level — :obj:`"low"`,
            :obj:`"medium"`, or :obj:`"high"` — applied only to model types
            that support it (e.g. :obj:`o1`, :obj:`o1mini`,
            :obj:`o1preview`, :obj:`o3mini`); ignored otherwise.
            (default: :obj:`None`)
        parallel_tool_calls (bool, optional): Whether the model may invoke
            tools in parallel. (default: :obj:`None`)
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    # Repetition penalties and token biasing.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict] = None
    # Generation limits and delivery mode.
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    stream: Optional[bool] = None
    # Output shaping and attribution.
    response_format: Optional[Union[Type[BaseModel], Dict]] = None
    user: Optional[str] = None
    # Tool-use and reasoning behavior.
    tool_choice: Optional[Union[Dict[str, str], str]] = None
    reasoning_effort: Optional[str] = None
    parallel_tool_calls: Optional[bool] = None
OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}

View File

@@ -0,0 +1,106 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
from camel.types import NotGiven
class OpenRouterConfig(BaseConfig):
    r"""Options for chat completions via OpenRouter's OpenAI-compatible API.

    Reference: https://openrouter.ai/docs/api-reference/parameters

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the API applies its own default.

    Args:
        temperature (float, optional): Sampling temperature between :obj:`0`
            and :obj:`2`; larger values yield more random output, smaller
            values more focused, deterministic output. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling alternative to
            temperature: only the tokens in the top ``top_p`` probability
            mass are considered, so :obj:`0.1` keeps just the top 10% of
            the distribution. (default: :obj:`None`)
        n (int, optional): Number of completion choices generated per input
            message. (default: :obj:`None`)
        response_format (object, optional): Desired output format. Setting
            {"type": "json_object"} enables JSON mode, which guarantees
            syntactically valid JSON; you must still instruct the model to
            produce JSON via a system or user message, otherwise it may
            stream whitespace until hitting the token limit, leaving the
            request seemingly "stuck". Content may also be truncated when
            finish_reason="length". Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
        stream (bool, optional): When True, partial message deltas are sent
            as data-only server-sent events while generation proceeds.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences at which the
            API stops generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on generated tokens; input
            plus output is capped by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens that already
            appeared, nudging the model toward new topics.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens proportionally to
            their existing frequency, reducing verbatim repetition.
            (default: :obj:`None`)
        user (str, optional): End-user identifier that helps OpenAI monitor
            and detect abuse. (default: :obj:`None`)
        tools (list[FunctionTool], optional): Tools the model may call;
            currently only functions are supported, with a maximum of 128.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Tool selection
            control: :obj:`"none"` forbids tool calls and produces a plain
            message, :obj:`"auto"` lets the model choose, :obj:`"required"`
            forces at least one call, and a concrete
            {"type": "function", "function": {"name": "my_function"}}
            mapping forces that specific tool. :obj:`"none"` is the default
            without tools; :obj:`"auto"` is the default when tools are
            present. (default: :obj:`None`)
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    # Repetition penalties.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    # Generation limits and delivery mode.
    max_tokens: Optional[Union[int, NotGiven]] = None
    stop: Optional[Union[str, Sequence[str], NotGiven]] = None
    stream: Optional[bool] = None
    # Output shaping and attribution.
    response_format: Optional[Union[dict, NotGiven]] = None
    user: Optional[str] = None
    # Tool-use routing.
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Names of all request parameters accepted by the OpenRouter backend.
OPENROUTER_API_PARAMS = set(OpenRouterConfig.model_fields)

View File

@@ -0,0 +1,102 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Dict, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
class PPIOConfig(BaseConfig):
    r"""Options for chat completions via PPIO's OpenAI-compatible API.

    Every field defaults to :obj:`None`, meaning the parameter is omitted
    from the request so the API applies its own default.

    Args:
        temperature (float, optional): Sampling temperature between :obj:`0`
            and :obj:`2`; larger values yield more random output, smaller
            values more focused, deterministic output. (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling alternative to
            temperature: only the tokens in the top ``top_p`` probability
            mass are considered, so :obj:`0.1` keeps just the top 10% of
            the distribution. (default: :obj:`None`)
        n (int, optional): Number of completion choices generated per input
            message. (default: :obj:`None`)
        response_format (object, optional): Desired output format. Setting
            {"type": "json_object"} enables JSON mode, which guarantees
            syntactically valid JSON; you must still instruct the model to
            produce JSON via a system or user message, otherwise it may
            stream whitespace until hitting the token limit, leaving the
            request seemingly "stuck". Content may also be truncated when
            finish_reason="length". Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
        stream (bool, optional): When True, partial message deltas are sent
            as data-only server-sent events while generation proceeds.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences at which the
            API stops generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): Upper bound on generated tokens; input
            plus output is capped by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens that already
            appeared, nudging the model toward new topics.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`; positive values penalize tokens proportionally to
            their existing frequency, reducing verbatim repetition.
            (default: :obj:`None`)
        logit_bias (dict, optional): Maps tokenizer token IDs to bias values
            from :obj:`-100` to :obj:`100` that are added to the logits
            before sampling. Values between :obj:`-1` and :obj:`1` shift
            selection likelihood; :obj:`-100` or :obj:`100` effectively ban
            or force the token. (default: :obj:`None`)
        user (str, optional): End-user identifier that helps OpenAI monitor
            and detect abuse. (default: :obj:`None`)
        tools (list[FunctionTool], optional): Tools the model may call;
            currently only functions are supported, with a maximum of 128.
    """

    # Sampling controls.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    # Repetition penalties and token biasing.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict] = None
    # Generation limits and delivery mode.
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    stream: Optional[bool] = None
    # Output shaping and attribution.
    response_format: Optional[Union[Type[BaseModel], Dict]] = None
    user: Optional[str] = None
PPIO_API_PARAMS = {param for param in PPIOConfig.model_fields.keys()}

View File

@@ -0,0 +1,91 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from camel.configs.base_config import BaseConfig
class QwenConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Qwen API. You can refer to the following link for more details:
https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
Args:
stream (bool, optional): Whether to stream the response.
(default: :obj:`None`)
temperature (float, optional): Controls the diversity and
focus of the generated results. Lower values make the output more
focused, while higher values make it more diverse.
(default: :obj:`None`)
top_p (float, optional): Controls the diversity and focus of
the generated results. Higher values make the output more diverse,
while lower values make it more focused. (default: :obj:`0.9`)
presence_penalty (float, optional): Controls the repetition
content in the generated results. Positive values reduce the
repetition of content, while negative values increase it.
(default: :obj:`None`)
response_format (Optional[Dict[str, str]], optional): Specifies the
format of the returned content. The available values are
`{"type": "text"}` or `{"type": "json_object"}`. Setting it to
`{"type": "json_object"}` will output a standard JSON string.
(default: :obj:`None`)
max_tokens (Optional[int], optional): Allows the model to
generate the maximum number of tokens.
(default: :obj:`None`)
seed (Optional[int], optional): Sets the seed parameter to make the
text generation process more deterministic, typically used to
ensure that the results are consistent across model runs. By
passing the same seed value (specified by you) in each model call
while keeping other parameters unchanged, the model is likely to
return the same result.
(default: :obj:`None`)
stop (Optional[Union[str, List]], optional): Using the stop parameter,
the model will automatically stop generating text when it is about
to include the specified string or token_id. You can use the stop
parameter to control the output of the model by passing sensitive
words. (default: :obj:`None`)
tools (List, optional): Specifies an array of tools that the model can
call. It can contain one or more tool objects. During a function
call process, the model will select one tool from the array.
(default: :obj:`None`)
extra_body (Optional[Dict[str, Any]], optional): Additional parameters
to be sent to the Qwen API. If you want to enable internet search,
you can set this parameter to `{"enable_search": True}`.
(default: :obj:`None`)
include_usage (bool, optional): When streaming, specifies whether to
include usage information in `stream_options`.
(default: :obj:`None`)
"""
stream: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
presence_penalty: Optional[float] = None
response_format: Optional[Dict[str, str]] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List]] = None
extra_body: Optional[Dict[str, Any]] = None
    def __init__(self, include_usage: bool = True, **kwargs):
        r"""Initialize the config and, when streaming is enabled, wire
        ``include_usage`` into ``stream_options``.

        Args:
            include_usage (bool): Whether usage information should be
                included in streaming responses. (default: :obj:`True`)
            **kwargs: Remaining config values, forwarded to the base
                config initializer.
        """
        super().__init__(**kwargs)
        # Only set stream_options when stream is True
        # Otherwise, it will raise error when calling the API
        # NOTE(review): ``stream_options`` is not declared as a field on
        # this class — assumes the base config permits assigning extra
        # attributes; confirm against BaseConfig's model settings.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}
QWEN_API_PARAMS = {param for param in QwenConfig.model_fields.keys()}

View File

@@ -0,0 +1,69 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
class RekaConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Reka API.

    Reference: https://docs.reka.ai/api-reference/chat/create

    Args:
        temperature (Optional[float], optional): the temperature to use for
            sampling, e.g. 0.5. (default: :obj:`None`)
        top_p (Optional[float], optional): the cumulative probability of
            tokens to generate, e.g. 0.9. (default: :obj:`None`)
        top_k (Optional[int], optional): Parameter which forces the model to
            only consider the tokens with the `top_k` highest probabilities at
            the next step. (default: :obj:`None`)
        max_tokens (Optional[int], optional): the maximum number of tokens to
            generate, e.g. 100. (default: :obj:`None`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list. (default: :obj:`None`)
        seed (Optional[int], optional): the random seed to use for sampling, e.
            g. 42. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        use_search_engine (Optional[bool]): Whether to consider using search
            engine to complete the request. Note that even if this is set to
            `True`, the model might decide to not use search.
            (default: :obj:`None`)
    """

    # Every field defaults to ``None`` so that only values the caller
    # explicitly sets are forwarded to the Reka API.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    use_search_engine: Optional[bool] = None
REKA_API_PARAMS = {param for param in RekaConfig().model_fields.keys()}

View File

@@ -0,0 +1,164 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
class SambaVerseAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    SambaVerse API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low probability
            responses.
            (default: :obj:`None`)
        max_tokens (Optional[int], optional): The maximum number of tokens to
            generate, e.g. 100.
            (default: :obj:`None`)
        repetition_penalty (Optional[float], optional): The parameter for
            repetition penalty. 1.0 means no penalty.
            (default: :obj:`None`)
        stop (Optional[Union[str,list[str]]]): Stop generation if this token
            is detected. Or if one of these tokens is detected when providing
            a string list.
            (default: :obj:`None`)
        stream (Optional[bool]): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            Currently SambaVerse API doesn't support stream mode.
            (default: :obj:`None`)
    """

    # Every field defaults to ``None`` so that only values the caller
    # explicitly sets are forwarded to the SambaVerse API.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    stop: Optional[Union[str, list[str]]] = None
    stream: Optional[bool] = None
# The set of parameter names accepted by the SambaVerse API.  Read from
# the class rather than a throwaway instance: instantiating the config here
# is unnecessary, and accessing ``model_fields`` on an instance is
# deprecated in recent pydantic releases.
SAMBA_VERSE_API_PARAMS = set(SambaVerseAPIConfig.model_fields.keys())
class SambaCloudAPIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length. (default: :obj:`NOT_GIVEN`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`NOT_GIVEN`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`NOT_GIVEN`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between:obj:` -1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present. (default: :obj:`None`)
    """

    # Fields using ``NOT_GIVEN`` are omitted entirely from the outgoing
    # request (OpenAI-client sentinel), unlike ``None`` which would be sent.
    # NOTE(review): ``tools`` is documented above but not declared here —
    # presumably it is declared on BaseConfig; confirm.
    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    # default_factory avoids sharing one mutable dict across instances.
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None
# The set of parameter names accepted by the SambaCloud API.  Read from
# the class rather than a throwaway instance: instantiating the config here
# is unnecessary, and accessing ``model_fields`` on an instance is
# deprecated in recent pydantic releases.
SAMBA_CLOUD_API_PARAMS = set(SambaCloudAPIConfig.model_fields.keys())

View File

@@ -0,0 +1,76 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
class SGLangConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Reference: https://sgl-project.github.io/references/sampling_params.html

    Args:
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        stream (bool, optional): Whether to stream the generated output in
            chunks. If set to `True`, the response will be streamed as it is
            generated. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[Dict[str, Any]], optional): A list of tool definitions
            that the model can dynamically invoke. Each tool should be
            defined as a dictionary following OpenAI's function calling
            specification format. For more details, refer to the OpenAI
            documentation. (default: :obj:`None`)
    """

    # Every field defaults to ``None`` so that only values the caller
    # explicitly sets are forwarded to the SGLang server.
    stop: Optional[Union[str, Sequence[str]]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stream: Optional[bool] = None
    max_tokens: Optional[int] = None
    # Was ``Optional[Union[List[Dict[str, Any]]]]``: a single-member Union
    # is a no-op, so it is dropped for clarity; runtime behavior unchanged.
    tools: Optional[List[Dict[str, Any]]] = None
SGLANG_API_PARAMS = {param for param in SGLangConfig.model_fields.keys()}

View File

@@ -0,0 +1,92 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, Union
from pydantic import BaseModel
from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN
class SiliconFlowConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    SiliconFlow API.

    Args:
        temperature (float, optional): Determines the degree of randomness
            in the response. (default: :obj:`None`)
        top_p (float, optional): The top_p (nucleus) parameter is used to
            dynamically adjust the number of choices for each predicted token
            based on the cumulative probabilities. (default: :obj:`None`)
        n (int, optional): Number of generations to return.
            (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. (default: :obj:`None`)
        stream (bool, optional): If set, tokens are returned as Server-Sent
            Events as they are made available. (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported. (default: :obj:`None`)
    """

    # Every field defaults to ``None`` so that only values the caller
    # explicitly sets are forwarded to the SiliconFlow API.
    # NOTE(review): ``tools`` is documented above but not declared here —
    # ``as_dict`` reads ``self.tools``, so presumably it is declared on
    # BaseConfig; confirm.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    response_format: Optional[Union[Type[BaseModel], dict]] = None
    frequency_penalty: Optional[float] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            # Local import to avoid a circular dependency with the toolkits
            # package at module import time.
            from camel.toolkits import FunctionTool

            # Validate that every configured tool is a FunctionTool and can
            # produce an OpenAI tool schema.
            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        # NOTE(review): tools are always replaced with NOT_GIVEN, so the
        # schemas built above never reach the outgoing payload — looks like
        # the loop exists only for validation; confirm this is intended.
        config_dict["tools"] = NOT_GIVEN
        return config_dict
# The set of parameter names accepted by the SiliconFlow API, derived from
# the SiliconFlowConfig field declarations.  ``set(...)`` replaces the
# redundant identity comprehension over ``.keys()``.
SILICONFLOW_API_PARAMS = set(SiliconFlowConfig.model_fields.keys())

View File

@@ -0,0 +1,100 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
class TogetherAIConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length. (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between:obj:` -1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`None`)
    """

    # Every field defaults to ``None`` (or an empty dict for logit_bias) so
    # that only values the caller explicitly sets are forwarded to the API.
    temperature: Optional[float] = None  # openai default: 1.0
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[dict] = None
    frequency_penalty: Optional[float] = None
    # default_factory avoids sharing one mutable dict across instances.
    logit_bias: dict = Field(default_factory=dict)
    user: Optional[str] = None
# The set of parameter names accepted by the Together AI API, derived from
# the TogetherAIConfig field declarations.  ``set(...)`` replaces the
# redundant identity comprehension over ``.keys()``.
TOGETHERAI_API_PARAMS = set(TogetherAIConfig.model_fields.keys())

View File

@@ -0,0 +1,110 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from pydantic import Field
from camel.configs.base_config import BaseConfig
# flake8: noqa: E501
class VLLMConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the generation
            reaches the token limit, resulting in a long-running and seemingly
            "stuck" request. Also note that the message content may be
            partially cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`None`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`None`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between:obj:` -1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`None`)
        logprobs: Whether to return log probabilities of the output tokens or
            not. If true, returns the log probabilities of each output token
            returned in the `logits` of `message`. (default: :obj:`None`)
        top_logprobs: An integer between 0 and 20 specifying the number of
            most likely tokens to return at each token position, each with an
            associated log probability. `logprobs` must be set to `true` if
            this parameter is used. (default: :obj:`None`)
    """

    # Every field defaults to ``None`` (or an empty dict for logit_bias) so
    # that only values the caller explicitly sets are forwarded to vLLM.
    temperature: Optional[float] = None  # openai default: 1.0
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[dict] = None
    frequency_penalty: Optional[float] = None
    # default_factory avoids sharing one mutable dict across instances.
    logit_bias: dict = Field(default_factory=dict)
    user: Optional[str] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
VLLM_API_PARAMS = {param for param in VLLMConfig.model_fields.keys()}

View File

@@ -0,0 +1,57 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Union
from camel.configs.base_config import BaseConfig
class YiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Yi API. You can refer to the following link for more details:
    https://platform.lingyiwanwu.com/docs/api-reference

    Args:
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` or
            specifying a particular tool via
            {"type": "function", "function": {"name": "some_function"}}
            can be used to guide the model to use tools more strongly.
            (default: :obj:`None`)
        max_tokens (int, optional): Specifies the maximum number of tokens
            the model can generate. This sets an upper limit, but does not
            guarantee that this number will always be reached.
            (default: :obj:`None`)
        top_p (float, optional): Controls the randomness of the generated
            results. Lower values lead to less randomness, while higher
            values increase randomness. (default: :obj:`None`)
        temperature (float, optional): Controls the diversity and focus of
            the generated results. Lower values make the output more focused,
            while higher values make it more diverse. (default: :obj:`None`)
        stream (bool, optional): If True, enables streaming output.
            (default: :obj:`None`)
    """

    # Every field defaults to ``None`` so that only values the caller
    # explicitly sets are forwarded to the Yi API.
    tool_choice: Optional[Union[dict[str, str], str]] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    stream: Optional[bool] = None
YI_API_PARAMS = {param for param in YiConfig.model_fields.keys()}

View File

@@ -0,0 +1,70 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
from typing import Optional, Sequence, Union
from camel.configs.base_config import BaseConfig
class ZhipuAIConfig(BaseConfig):
    r"""Parameters for ZhipuAI chat completions through the
    OpenAI-compatible endpoint.

    Reference: https://open.bigmodel.cn/dev/api#glm-4v

    Args:
        temperature (float, optional): Sampling temperature in
            :obj:`[0, 2]`. Larger values produce more random output;
            smaller values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): Nucleus-sampling cutoff, an alternative
            to temperature sampling: only tokens within the top
            :obj:`top_p` probability mass are considered, so :obj:`0.1`
            keeps just the top 10%. (default: :obj:`None`)
        stream (bool, optional): When True, partial message deltas are
            streamed back as data-only server-sent events as they become
            available. (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences at which
            the API stops generating further tokens.
            (default: :obj:`None`)
        max_tokens (int, optional): Maximum number of tokens to generate
            for the completion; input plus output must fit within the
            model's context length. (default: :obj:`None`)
        tools (list[FunctionTool], optional): Tools the model may call.
            Only functions are currently supported (at most 128); each
            entry describes a function the model may produce JSON
            arguments for.
        tool_choice (Union[dict[str, str], str], optional): Controls tool
            selection. :obj:`"none"` disables tool calls and forces a
            plain message; :obj:`"auto"` lets the model choose between a
            message and tool calls; :obj:`"required"` forces at least one
            tool call; and a dict of the form
            {"type": "function", "function": {"name": "my_function"}}
            pins a specific tool. :obj:`"none"` is the default when no
            tools are present, :obj:`"auto"` when they are.
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    tool_choice: Optional[Union[dict[str, str], str]] = None
# Names of every configuration field accepted by the ZhipuAI API layer.
ZHIPUAI_API_PARAMS = set(ZhipuAIConfig.model_fields.keys())

View File

@@ -0,0 +1,19 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .alpaca_collector import AlpacaDataCollector
from .base import BaseDataCollector
from .sharegpt_collector import ShareGPTDataCollector
# Public API of the data-collector package.
__all__ = ["BaseDataCollector", "AlpacaDataCollector", "ShareGPTDataCollector"]

View File

@@ -0,0 +1,127 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Self
from camel.agents import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages import AlpacaItem, BaseMessage
from camel.schemas import OpenAISchemaConverter
# ruff: noqa: E501
DEFAULT_CONVERTER_PROMPTS = """
Extract key entities and attributes from the conversations
and convert them into a structured JSON format.
For example:
Instruction: You are a helpful assistant.
User: When is the release date of the video game Portal?
Assistant: The release date of the video game Portal is October 9.
Your output should be:
{
"instruction": "You are a helpful assistant. When is the release date of the video game Portal?",
"input": "",
"output": "The release date of the video game Portal is October 9."
}
"""
class AlpacaDataCollector(BaseDataCollector):
    r"""Collects a single-turn conversation and converts it into the
    Alpaca ``instruction``/``input``/``output`` format."""

    def __init__(self) -> None:
        super().__init__()
        # Captured from the first injected agent by :meth:`record`.
        self.system_message: Optional[BaseMessage] = None
        self.agent_name: Optional[str] = None

    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent to inject.
        """
        if not self.agent_name:
            # Single agent or first agent of a list supplies the
            # identity and the system prompt.
            first = agent if isinstance(agent, ChatAgent) else agent[0]
            self.agent_name = first.role_name
            self.system_message = first._system_message
        super().record(agent)
        return self

    def convert(self) -> Dict[str, Any]:
        r"""Convert the collected data into a dictionary.

        Returns:
            Dict[str, Any]: One Alpaca-style record with
                ``instruction``, ``input`` and ``output`` keys.

        Raises:
            ValueError: If no agent was injected, nothing was collected,
                or the history is not exactly one user/assistant pair.
        """
        if self.agent_name is None:
            raise ValueError("No agent injected")

        history = self.get_agent_history(self.agent_name)
        if not history:
            raise ValueError("No data collected.")

        # Drop a leading system message so exactly one
        # user/assistant pair remains.
        if len(history) == 3 and history[0].role == "system":
            history = history[1:]
        elif len(history) != 2:
            raise ValueError(
                f"AlpacaDataCollector only supports one message pair, but "
                f"got {len(history)}"
            )

        user_msg, assistant_msg = history
        system_part = (
            self.system_message.content if self.system_message else ""
        )
        record = {
            "instruction": system_part + str(user_msg.message),
            "input": "",
            "output": assistant_msg.message,
        }
        self.data.append(record)
        return record

    def llm_convert(
        self,
        converter: Optional[OpenAISchemaConverter] = None,
        prompt: Optional[str] = None,
    ) -> Dict[str, str]:
        r"""Convert collected data using an LLM schema converter.

        Args:
            converter (Optional[OpenAISchemaConverter], optional):
                The converter to use. (default: :obj:`OpenAISchemaConverter`)
            prompt (Optional[str], optional): Prompt to guide the conversion.
                (default: :obj:`DEFAULT_CONVERTER_PROMPTS`)

        Returns:
            Dict[str, str]: The converted data.

        Raises:
            ValueError: If no agent is injected or data cannot be collected.
        """
        prompt = prompt or DEFAULT_CONVERTER_PROMPTS
        converter = converter or OpenAISchemaConverter()

        system = self.system_message.content if self.system_message else ""
        lines = [f"Instruction: {system}\n"]
        for msg in self.get_agent_history(str(self.agent_name)):
            if msg.role == "user":
                lines.append(f"User: {msg.message}\n")
            else:
                lines.append(f"{msg.name}: {msg.message}\n")

        return converter.convert(
            "\n".join(lines), AlpacaItem, prompt=prompt
        ).model_dump()

View File

@@ -0,0 +1,211 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import UUID
from typing_extensions import Self
from camel.agents import ChatAgent
class CollectorData:
    r"""A single recorded message, as stored by the data collectors."""

    def __init__(
        self,
        id: UUID,
        name: str,
        role: Literal["user", "assistant", "system", "tool"],
        message: Optional[str] = None,
        function_call: Optional[Dict[str, Any]] = None,
    ) -> None:
        r"""Create a data item storing information about one message.

        Used by the data collector.

        Args:
            id (UUID): The id of the message.
            name (str): The name of the agent.
            role (Literal["user", "assistant", "system", "tool"]):
                The role of the message.
            message (Optional[str], optional): The message.
                (default: :obj:`None`)
            function_call (Optional[Dict[str, Any]], optional):
                The function call. (default: :obj:`None`)

        Raises:
            ValueError: If the role is not supported, if a system
                message carries a function call, or if neither a message
                nor a function call is provided.
        """
        if role not in ("user", "assistant", "system", "tool"):
            raise ValueError(f"Role {role} not supported")
        if function_call and role == "system":
            raise ValueError("System role cannot have function call")
        if not (message or function_call):
            raise ValueError(
                "Either message or function call must be provided"
            )
        self.id = id
        self.name = name
        self.role = role
        self.message = message
        self.function_call = function_call

    @staticmethod
    def from_context(name, context: Dict[str, Any]) -> "CollectorData":
        r"""Build a :class:`CollectorData` item from an OpenAI-style
        context entry.

        Args:
            name (str): The name of the agent.
            context (Dict[str, Any]): The context.

        Returns:
            CollectorData: The data collector.
        """
        return CollectorData(
            id=uuid.uuid4(),
            name=name,
            role=context["role"],
            message=context["content"],
            function_call=context.get("tool_calls"),
        )
class BaseDataCollector(ABC):
    r"""Base class for data collectors."""

    def __init__(self) -> None:
        r"""Create a data collector."""
        # Chronological log of every message recorded via :meth:`step`.
        self.history: List[CollectorData] = []
        self._recording = False
        # Registered (name, agent) pairs; names are unique.
        self.agents: List[Tuple[str, ChatAgent]] = []
        # Converted records produced by subclasses' ``convert``.
        self.data: List[Dict[str, Any]] = []

    def step(
        self,
        role: Literal["user", "assistant", "system", "tool"],
        name: Optional[str] = None,
        message: Optional[str] = None,
        function_call: Optional[Dict[str, Any]] = None,
    ) -> Self:
        r"""Record a message.

        Args:
            role (Literal["user", "assistant", "system", "tool"]):
                The role of the message.
            name (Optional[str], optional): The name of the agent.
                (default: :obj:`None`)
            message (Optional[str], optional): The message to record.
                (default: :obj:`None`)
            function_call (Optional[Dict[str, Any]], optional):
                The function call to record. (default: :obj:`None`)

        Returns:
            Self: The data collector.
        """
        entry = CollectorData(
            id=uuid.uuid4(),
            name=name if name else role,
            role=role,
            message=message,
            function_call=function_call,
        )
        self.history.append(entry)
        return self

    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Record agents.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent(s) to inject.
        """
        agents = agent if isinstance(agent, list) else [agent]
        for candidate in agents:
            # Fall back to a synthetic name when the agent has none.
            name = (
                candidate.role_name
                or f"{candidate.__class__.__name__}_{len(self.agents)}"
            )
            if any(name == existing for existing, _ in self.agents):
                raise ValueError(f"Name {name} already exists")
            self.agents.append((name, candidate))
        return self

    def start(self) -> Self:
        r"""Start recording."""
        self._recording = True
        return self

    def stop(self) -> Self:
        r"""Stop recording."""
        self._recording = False
        return self

    @property
    def recording(self) -> bool:
        r"""Whether the collector is recording."""
        return self._recording

    def reset(self, reset_agents: bool = True):
        r"""Reset the collector.

        Args:
            reset_agents (bool, optional):
                Whether to reset the agents. Defaults to True.
        """
        self.history = []
        if reset_agents:
            for _, registered in self.agents:
                registered.reset()

    @abstractmethod
    def convert(self) -> Any:
        r"""Convert the collected data."""
        pass

    @abstractmethod
    def llm_convert(self, converter: Any, prompt: Optional[str] = None) -> Any:
        r"""Convert the collected data."""
        pass

    def get_agent_history(self, name: str) -> List[CollectorData]:
        r"""Get the message history of an agent.

        Args:
            name (str): The name of the agent.

        Returns:
            List[CollectorData]: The message history of the agent
        """
        if self.history:
            return [item for item in self.history if item.name == name]
        # Nothing recorded explicitly: fall back to the agent's own
        # memory, if the agent is registered.
        for registered_name, registered_agent in self.agents:
            if registered_name == name:
                context, _ = registered_agent.memory.get_context()
                return [
                    CollectorData.from_context(name, dict(entry))
                    for entry in context
                ]
        return []

View File

@@ -0,0 +1,216 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from pydantic import BaseModel
from typing_extensions import Self
from camel.agents import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages import BaseMessage
from camel.messages.conversion.conversation_models import (
ShareGPTConversation,
ShareGPTMessage,
)
from camel.schemas import OpenAISchemaConverter
from camel.toolkits import FunctionTool
# Maps the "from" tags produced during collection onto the speaker tags
# accepted by ShareGPTMessage: tool observations are attributed to
# "human", function calls to "gpt".
FROM_HASH = {
    "human": "human",
    "gpt": "gpt",
    "observation": "human",
    "function_call": "gpt",
}
# ruff: noqa: E501
DEFAULT_CONVERTER_PROMPTS = """
Extract key entities and attributes from the conversations
and convert them into a structured JSON format.
For example:
System: You are a helpful assistant
Tools: [{"name": "get_release_date", "arguments": ["Portal"]}]
User: When is the release date of the video game Portal?
Assistant: The release date of the video game Portal is October 9, 2007.
Your output should be:
{
"system": "You are a helpful assistant",
"tools": "[{"name": "get_release_date", "arguments": ["Portal"]}]",
"conversations": [
{"from": "human", "value": "When is the release date of the video game Portal?"},
{"from": "gpt", "value": "The release date of the video game Portal is October 9, 2007."}
]
}
"""
class ConversationItem(BaseModel):
    # One turn of a ShareGPT conversation; ``from_`` is meant to be
    # (de)serialized under the key "from" via the Config alias below.
    from_: Literal["human", "gpt", "function_call", "observation"]
    value: str

    class Config:
        # NOTE(review): the ``fields`` Config key is the pydantic-v1
        # alias mechanism; under pydantic v2 it is ignored — confirm the
        # "from" alias still applies (v2 would use Field(alias="from")).
        fields: ClassVar[Dict[str, str]] = {"from_": "from"}
        extra = "forbid"
class ShareGPTData(BaseModel):
    # Full ShareGPT record: system prompt, JSON-encoded tool schemas,
    # and the ordered conversation turns.
    system: str
    tools: str
    conversations: List[ConversationItem]

    class Config:
        extra = "forbid"
class ShareGPTDataCollector(BaseDataCollector):
    r"""Collects conversations (including tool calls) and converts them
    into the ShareGPT format."""

    def __init__(self) -> None:
        super().__init__()
        # Captured from the first injected agent by :meth:`record`.
        self.system_message: Optional[BaseMessage] = None
        self.agent_name: Optional[str] = None
        self.tools: List[FunctionTool] = []

    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent to inject.
        """
        if not self.agent_name:
            _agent = agent if isinstance(agent, ChatAgent) else agent[0]
            self.agent_name = _agent.role_name
            self.system_message = _agent._system_message
            self.tools += list(_agent.tool_dict.values())
        super().record(agent)
        return self

    def convert(self) -> Dict[str, Any]:
        r"""Convert the collected data into a dictionary.

        Returns:
            Dict[str, Any]: ShareGPT-style record with ``system``,
                ``tools`` (a JSON string) and ``conversations`` keys.

        Raises:
            ValueError: If no agent is injected or no data was collected.
        """
        if self.agent_name is None:
            raise ValueError("No agent injected")
        history = self.get_agent_history(self.agent_name)
        if not history:
            raise ValueError("No data collected.")
        # Fix: ``ensure_ascii`` belongs to ``json.dumps``. It was
        # previously passed to ``dict(...)``, which leaked a spurious
        # ``"ensure_ascii": False`` entry into every converted record.
        data = dict(
            system=self.system_message.content if self.system_message else "",
            tools=json.dumps(
                [t.get_openai_tool_schema()["function"] for t in self.tools],
                ensure_ascii=False,
            ),
            conversations=[],
        )
        conversations: List[Any] = []
        for _data in history:
            role, message = _data.role, _data
            if role == "user":
                conversations.append(
                    {"from": "human", "value": message.message}
                )
            elif role == "assistant":
                if message.function_call:
                    conversations.append(
                        {
                            "from": "function_call",
                            "value": json.dumps(
                                message.function_call, ensure_ascii=False
                            ),
                        }
                    )
                else:
                    conversations.append(
                        {"from": "gpt", "value": message.message}
                    )
            elif role == "function" or role == "tool":
                conversations.append(
                    {
                        "from": "observation",
                        "value": json.dumps(
                            message.message, ensure_ascii=False
                        ),  # type: ignore[attr-defined]
                    }
                )
        data["conversations"] = conversations
        self.data.append(data)
        return data

    def llm_convert(
        self,
        converter: Optional[OpenAISchemaConverter] = None,
        prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        r"""Convert collected data using an LLM schema converter.

        Args:
            converter (Optional[OpenAISchemaConverter], optional):
                The converter to use. (default: :obj:`OpenAISchemaConverter`)
            prompt (Optional[str], optional): Prompt to guide the conversion.
                (default: :obj:`DEFAULT_CONVERTER_PROMPTS`)

        Returns:
            Dict[str, str]: The converted data.

        Raises:
            ValueError: If no agent is injected or data cannot be collected.
        """
        prompt = prompt or DEFAULT_CONVERTER_PROMPTS
        converter = converter or OpenAISchemaConverter()
        system = self.system_message.content if self.system_message else ""
        context = [f"System: {system}\n"]
        context.append(
            "Tools: "
            + json.dumps(
                [t.get_openai_tool_schema()["function"] for t in self.tools],
                ensure_ascii=False,
            )
        )
        for _data in self.get_agent_history(str(self.agent_name)):
            role, message = _data.role, _data
            # NOTE(review): due to operator precedence the user branch
            # yields "User: <name>: " while every other role yields
            # "<role>: " — confirm this prefixing is intended.
            prefix = (
                f"{role}: " if role != "user" else "User: " + f"{_data.name}: "
            )
            if message.function_call:
                context.append(
                    prefix
                    + json.dumps(message.function_call, ensure_ascii=False)
                )
            elif role == "function" or role == "tool":
                context.append(
                    prefix + json.dumps(message.message, ensure_ascii=False)
                )  # type: ignore[attr-defined]
            else:
                context.append(prefix + str(message.message))
        return converter.convert(
            "\n".join(context), ShareGPTData, prompt
        ).model_dump()

    @staticmethod
    def to_sharegpt_conversation(data: Dict[str, Any]) -> ShareGPTConversation:
        r"""Build a :class:`ShareGPTConversation` from a converted record.

        Args:
            data (Dict[str, Any]): A record produced by :meth:`convert`
                or :meth:`llm_convert`.

        Returns:
            ShareGPTConversation: System message followed by the turns,
                with "from" tags normalized through :data:`FROM_HASH`.
        """
        messages = [
            ShareGPTMessage(from_="system", value=data["system"])  # type: ignore[call-arg]
        ]
        for item in data["conversations"]:
            messages.append(
                ShareGPTMessage(  # type: ignore[call-arg]
                    from_=FROM_HASH[item["from"]],
                    value=item["value"],
                )
            )
        return ShareGPTConversation(root=messages)

23
camel/datagen/__init__.py Normal file
View File

@@ -0,0 +1,23 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .cot_datagen import CoTDataGenerator
from .self_improving_cot import SelfImprovingCoTPipeline
from .self_instruct import SelfInstructPipeline
# Public API of the datagen package.
__all__ = [
    "CoTDataGenerator",
    "SelfInstructPipeline",
    "SelfImprovingCoTPipeline",
]

View File

@@ -0,0 +1,448 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from datetime import datetime
from typing import Annotated, Dict, Optional, Union
from pydantic import BaseModel, Field, confloat
from camel.agents import ChatAgent
from camel.logger import get_logger
# Module-level logger shared by the CoT data-generation pipeline below.
logger = get_logger('CoTDataGenerator')
class AgentResponse(BaseModel):
    r"""Model for structured agent responses.

    A Pydantic model class that represents structured responses from agents,
    including a similarity score that measures the quality of the response.

    Args:
        score (float): A similarity score between 0 and 1 that compares the
            current answer to the correct answer. Must be within the range
            [0, 1].
    """

    # Field-level constraints replace the deprecated
    # ``Annotated[float, confloat(...)]`` pattern; the validation
    # behavior (0 <= score <= 1) is identical.
    score: float = Field(
        ...,
        ge=0,
        le=1,
        description="""Similarity score between 0 and 1
        comparing current answer to correct answer""",
    )
class VerificationResponse(BaseModel):
    r"""Model for structured verification responses.

    A Pydantic model class that represents verification results from agents,
    indicating whether an answer is correct or not.

    Args:
        is_correct (bool): Boolean indicating if the answer is correct.
    """

    # Parsed from the verifier agent's structured ("true"/"false") output.
    is_correct: bool = Field(
        ...,
        description="Boolean indicating if the answer is correct",
    )
class CoTDataGenerator:
    r"""Class for generating and managing data through chat agent interactions.

    This module implements a sophisticated Chain of Thought data generation
    system that combines several key algorithms to produce high-quality
    reasoning paths. Methods implemented:

    1. Monte Carlo Tree Search (MCTS)
    2. Binary Search Error Detection
    3. Dual-Agent Verification System
    4. Solution Tree Management

    Args:
        chat_agent (Optional[ChatAgent]): Optional single agent
            for both tasks (legacy mode). (default::obj:`None`)
        generator_agent (Optional[ChatAgent]): Optional specialized agent for
            answer generation. (default::obj:`None`)
        verifier_agent (Optional[ChatAgent]): Optional specialized agent for
            answer verification. (default::obj:`None`)
        golden_answers (Dict[str, str]): Dictionary containing pre-defined
            correct answers for validation and comparison. Required for answer
            verification.
        search_limit (int): Maximum number of search iterations allowed.
            (default::obj:`100`)
    """

    def __init__(
        self,
        chat_agent: Optional[ChatAgent] = None,
        *,
        generator_agent: Optional[ChatAgent] = None,
        verifier_agent: Optional[ChatAgent] = None,
        golden_answers: Dict[str, str],
        search_limit: int = 100,
    ):
        r"""Initialize the CoTDataGenerator.

        This constructor supports both single-agent and dual-agent modes:
        1. Single-agent mode (legacy): Pass a single chat_agent that will be
           used for both generation and verification.
        2. Dual-agent mode: Pass separate generator_agent and verifier_agent
           for specialized tasks.

        Args:
            chat_agent (Optional[ChatAgent]): Optional single agent for both
                tasks (legacy mode). (default::obj:`None`)
            generator_agent (Optional[ChatAgent]): Optional specialized agent
                for answer generation. (default::obj:`None`)
            verifier_agent (Optional[ChatAgent]): Optional specialized agent
                for answer verification. (default::obj:`None`)
            golden_answers (Dict[str, str]): Dictionary containing pre-defined
                correct answers for validation and comparison. Required for
                answer verification.
            search_limit (int): Maximum number of search iterations allowed.
                (default::obj:`100`)

        Raises:
            ValueError: If both a single chat_agent and specialized agents
                are given, or if only one of the specialized agents is given.
        """
        if chat_agent is not None:
            if generator_agent is not None or verifier_agent is not None:
                raise ValueError(
                    "Cannot specify both chat_agent and "
                    "generator/verifier agents"
                )
            self.generator_agent = chat_agent
            self.verifier_agent = chat_agent
        else:
            if generator_agent is None or verifier_agent is None:
                raise ValueError(
                    "Must specify either chat_agent or both generator and "
                    "verifier agents"
                )
            self.generator_agent = generator_agent
            self.verifier_agent = verifier_agent
        self.golden_answers = golden_answers
        self.search_limit = search_limit
        # Maps each solved question to its final solution and the
        # position of the first detected error.
        self.solution_tree: Dict[str, Dict[str, Union[str, int]]] = {}
        logger.info(
            "CoTDataGenerator initialized with search_limit=%d", search_limit
        )

    def get_answer(self, question: str, context: str = "") -> str:
        r"""Get an answer from the chat agent for a given question.

        Args:
            question (str): The question to ask.
            context (str): Additional context for the question.
                (default::obj:`""`)

        Returns:
            str: The generated answer.
        """
        prompt = f"""
        Please think step by step and solve this problem: {question}
        Existing content: {context}
        Requirements:
        1. Analyze the problem requirements
        2. List the steps to solve the problem
        3. Execute the solution process
        4. Provide the final answer
        Please explain the thought process of each step in detail.
        """
        self.generator_agent.reset()
        response = self.generator_agent.step(prompt)
        answer = response.msgs[0].content
        logger.info("AI thought process:\n%s", answer)
        return answer

    def verify_answer(self, question: str, answer: str) -> bool:
        r"""Verify if a generated answer is semantically equivalent to
        the golden answer for a given question.

        Args:
            question (str): The question being answered.
            answer (str): The answer to verify.

        Returns:
            bool: True if the answer matches the golden answer based on
                semantic equivalence (meaning the core content and meaning are
                the same, even if the exact wording differs).
                False in the following cases:
                - If the provided question doesn't exist in the golden answers
                - If the answer's meaning differs from the golden answer

        Raises:
            ValueError: If no golden answer exists for the question.
        """
        golden_answer = self.golden_answers.get(question)
        if not golden_answer:
            raise ValueError(
                f"No golden answer found for question: {question}"
            )
        prompt = (
            f"Question: {question}\n"
            f"Student Answer: {answer}\n"
            f"Correct Answer: {golden_answer}\n"
            "Is the student's answer correct? Please respond with 'true' or "
            "'false' only."
        )
        self.verifier_agent.reset()
        response = self.verifier_agent.step(
            prompt, response_format=VerificationResponse
        )
        is_correct = response.msgs[0].parsed.is_correct  # type:ignore [union-attr]
        logger.info("Answer verification result: %s", is_correct)
        return is_correct

    def monte_carlo_tree_search(
        self, question: str, partial_solution: str = ""
    ) -> float:
        r"""Perform Monte Carlo Tree Search to find the best solution.

        Process:
        a. Selection: Choose promising partial solutions based on previous
           scores
        b. Expansion: Generate new solution steps using the generator agent
        c. Simulation: Evaluate solution quality using similarity scores
        d. Backpropagation: Update solution tree with new findings

        Args:
            question (str): The question to solve.
            partial_solution (str): The current partial solution.
                (default::obj:`""`)

        Returns:
            float: The similarity score between the current
                solution and golden answer.

        Raises:
            ValueError: If no golden answer exists for the question.
        """
        if question not in self.golden_answers:
            raise ValueError(
                f"No golden answer found for question: {question}"
            )
        golden_answer = self.golden_answers[question]
        prompt = (
            f"Please evaluate this solution and "
            f"give a score between 0-1:\n"
            f"Question: {question}\n"
            f"Solution: {partial_solution}\n"
            f"Correct answer: {golden_answer}\n"
            f"Return a JSON object with a single field 'score' containing "
            f"a float between 0 and 1, like this: {{'score': 0.85}}\n"
        )
        self.generator_agent.reset()
        response = self.generator_agent.step(
            prompt, response_format=AgentResponse
        )
        return response.msgs[0].parsed.score  # type: ignore [union-attr]

    def binary_search_error(self, question: str, solution: str) -> int:
        r"""Use binary search to locate the first error in the solution.

        This method splits the solution into sentences using both English and
        Chinese sentence delimiters and performs binary search to find the
        first error.

        Args:
            question (str): The question being solved.
            solution (str): The complete solution to analyze.

        Returns:
            int: The position of the first error found in the solution.
                Returns -1. If no errors are found (all sentences are
                correct).
        """
        logger.info("Starting binary search for error location")
        # Split by both English period and Chinese period.
        # Fix: the first ``replace`` argument must be the Chinese full
        # stop '。'; an empty-string pattern would insert a '.' between
        # every character and destroy the sentence split.
        sentences = [
            s.strip()
            for s in solution.replace('。', '.').split('.')
            if s.strip()
        ]
        # First check if the entire solution is correct
        if self.verify_answer(question, solution):
            return -1
        left, right = 0, len(sentences)
        while left < right:
            mid = (left + right) // 2
            partial_solution = '. '.join(sentences[:mid]) + '.'
            logger.info("Checking solution fragment:\n%s", partial_solution)
            # Verify if the current part is correct
            is_correct = self.verify_answer(question, partial_solution)
            if is_correct:
                left = mid + 1
            else:
                right = mid
        logger.info("First error position found: sentence %d", left)
        return left

    def solve(self, question: str) -> str:
        r"""Solve a question using a multi-step approach.

        The solution process follows these steps:
        1. Try to solve directly - if correct, return the solution
        2. If not correct, use Monte Carlo Tree Search to find a good solution
        3. If the solution isn't perfect, use binary search to locate errors
        4. Generate a new solution based on the correct part

        Args:
            question (str): The question to solve.

        Returns:
            str: The best solution found.
        """
        # 1. Try direct solution first
        solution = self.get_answer(question)
        if self.verify_answer(question, solution):
            logger.info("Initial solution is correct")
            return solution
        # 2. If direct solution fails, try Monte Carlo Tree Search
        # to find a solution with high similarity score
        best_solution = ""
        best_score: float = 0.0
        for i in range(self.search_limit):
            # Generate new answer
            current_solution = self.get_answer(question, best_solution)
            try:
                # Score the candidate against the golden answer.
                # Fix: previously the scoring prompt was duplicated inline
                # and the generator agent was stepped twice per iteration,
                # doubling the LLM calls; reuse the MCTS scorer instead.
                score = self.monte_carlo_tree_search(
                    question, current_solution
                )
                # Exit early if we find a very good solution (score > 0.9)
                if score > 0.9:
                    logger.info(
                        "Found excellent solution with score %.2f. "
                        "Stopping search early.",
                        score,
                    )
                    return current_solution
                if score > best_score:
                    best_score = score
                    best_solution = current_solution
                logger.info(
                    "Current search progress: %d/%d, best score: %.2f",
                    i + 1,
                    self.search_limit,
                    best_score,
                )
            except Exception as e:
                logger.error("Error parsing agent response: %s", str(e))
                continue
        # 3. If the answer is not completely correct,
        # use binary search to locate the error
        error_pos = self.binary_search_error(question, best_solution)
        # If no errors found (error_pos == -1), return the current solution
        if error_pos == -1:
            logger.info("No specific errors found in the solution")
            return best_solution
        # 4. Generate new solution based on correct part
        correct_part = '. '.join(best_solution.split('. ')[:error_pos]) + '.'
        final_solution = self.get_answer(question, correct_part)
        self.solution_tree[question] = {
            "solution": final_solution,
            "error_position": error_pos,
        }
        return final_solution

    def import_qa_from_json(self, data: Union[str, Dict[str, str]]) -> bool:
        r"""Import question and answer data from either a JSON file or a
        dictionary.

        Args:
            data (Union[str, Dict[str, str]]): Either a path to a JSON file
                containing QA pairs or a dictionary of question-answer pairs.
                If a string is provided, it's treated as a file path.
                The expected format is:
                {"question1": "answer1",
                "question2": "answer2",
                ...}

        Returns:
            bool: True if import was successful, False otherwise.
        """
        try:
            if isinstance(data, str):
                logger.info("Loading QA pairs from file: %s", data)
                with open(data, 'r', encoding='utf-8') as f:
                    qa_data = json.load(f)
            else:
                logger.info("Loading QA pairs from provided dictionary")
                qa_data = data
            # Validate the data format
            if not isinstance(qa_data, dict):
                logger.error("Invalid data format: expected dictionary")
                return False
            # Update golden answers
            self.golden_answers.update(qa_data)
            logger.info("Successfully imported %d QA pairs", len(qa_data))
            return True
        except Exception as e:
            logger.error("Error importing QA data: %s", str(e))
            return False

    def export_solutions(self, filepath: str = 'solutions.json') -> None:
        r"""Export the solution process and results to a JSON file.

        Exports the solution tree, golden answers,
        and export timestamp to a JSON file.
        The exported data includes:
        - solutions: The solution tree
          with intermediate steps
        - golden_answers: The reference answers used for verification
        - export_time: ISO format timestamp of the export

        Args:
            filepath (str, optional): Path where the JSON file will be saved.
                (default::obj:`'solutions.json'`)

        Returns:
            None: The method writes to a file and logs the result but does not
                return any value.
        """
        export_data = {
            "solutions": self.solution_tree,
            "golden_answers": self.golden_answers,
            "export_time": datetime.now().isoformat(),
        }
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, ensure_ascii=False, indent=2)
            # Lazy %-style args for consistency with the rest of the class.
            logger.info("Solutions exported successfully to %s", filepath)
        except Exception as e:
            logger.error("Error exporting solutions: %s", str(e))

View File

@@ -0,0 +1,20 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .evol_instruct import EvolInstructPipeline

# Bug fix: `MathEvolInstructTemplates` was listed in `__all__` without being
# imported, which breaks `from <pkg> import *` and attribute access.
from .templates import MathEvolInstructTemplates

__all__ = [
    'EvolInstructPipeline',
    'MathEvolInstructTemplates',
]

View File

@@ -0,0 +1,424 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import time
from concurrent.futures import ThreadPoolExecutor
from math import ceil
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from tqdm import tqdm
from camel.agents import ChatAgent
from camel.datagen.evol_instruct.scorer import BaseScorer, GeneralScorer
from camel.datagen.evol_instruct.templates import EvolInstructTemplates
from camel.logger import get_logger
logger = get_logger(__name__)
class EvolInstructPipeline:
    r"""Pipeline for evolving prompts using the Evol-Instruct methodology.

    Supports custom templates defining evolution strategies and methods. The
    pipeline leverages language models to iteratively refine prompts through
    specified evolution strategies.

    Args:
        templates (Type[EvolInstructTemplates]): Template class containing
            evolution strategy and method definitions. Must provide
            `EVOL_METHODS` and `STRATEGY` attributes.
            (default: :obj:`EvolInstructTemplates`)
        agent (Optional[ChatAgent]): Chat agent instance for LLM interaction.
            If :obj:`None`, initializes with a default ChatAgent.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        templates: Type = EvolInstructTemplates,
        agent: Optional[ChatAgent] = None,
    ) -> None:
        r"""Initialize pipeline with templates and language model agent.

        Args:
            templates (Type[EvolInstructTemplates]): Template class containing
                evolution strategy configurations.
                (default: :obj:`EvolInstructTemplates`)
            agent (Optional[ChatAgent]): Preconfigured chat agent instance.
                Creates a default ChatAgent if not provided.
                (default: :obj:`None`)
        """
        self.templates = templates
        self.agent = agent or ChatAgent()

    def _resolve_evolution_method(self, method_key: str) -> str:
        r"""Resolve evolution method key to concrete implementation.

        Args:
            method_key (str): Input method identifier. Can be:
                - Direct method key from templates.EVOL_METHODS
                - Strategy name from templates.STRATEGY keys

        Returns:
            str: Resolved method key from EVOL_METHODS
        """
        if method_key in self.templates.EVOL_METHODS:
            return method_key
        if method_key.upper() in self.templates.STRATEGY:
            # A strategy name stands for a family of methods; pick one.
            strategy = self.templates.STRATEGY[method_key.upper()]
            return random.choice(strategy["methods"])
        logger.warning(
            f"Invalid evolution method: {method_key}. "
            f"Using random selection."
        )
        # Iterating the dict yields its keys, i.e. valid method identifiers.
        return random.choice(list(self.templates.EVOL_METHODS))

    def _get_evolution_methods(
        self,
        method: Union[str, List[str]],
        num_generations: int = 2,
    ) -> List[str]:
        r"""Get list of evolution methods based on input specification.

        Args:
            method (Union[str, List[str]]): Specification for method
                selection. Can be:
                - Strategy name for methods from that strategy
                - Specific method name
                - List of method specifications
            num_generations (int): Number of methods to return.

        Returns:
            List[str]: List of resolved method names
        """
        candidate_methods: List[str] = []
        if isinstance(method, list):
            candidate_methods = [
                self._resolve_evolution_method(spec) for spec in method
            ]
        elif isinstance(method, str):
            if method.upper() in self.templates.STRATEGY:
                strategy = self.templates.STRATEGY[method.upper()]
                candidate_methods = list(strategy["methods"])
            else:
                candidate_methods = [self._resolve_evolution_method(method)]

        # Remove duplicates while preserving order.
        unique_candidates = list(dict.fromkeys(candidate_methods))
        if not unique_candidates:
            # Robustness fix: an empty spec (e.g. `[]`) previously crashed in
            # random.choice below; fall back to the full method pool.
            unique_candidates = list(self.templates.EVOL_METHODS)

        if len(unique_candidates) >= num_generations:
            # Enough distinct methods: sample without replacement.
            methods = random.sample(unique_candidates, num_generations)
        else:
            # Not enough: use them all and pad with random repeats.
            methods = unique_candidates.copy()
            while len(methods) < num_generations:
                methods.append(random.choice(unique_candidates))
        return methods

    def _generate_single_evolution(
        self,
        prompt: str,
        method: str,
        return_method: bool = False,
    ) -> Tuple[str, str]:
        r"""Generate a single evolved prompt from a seed prompt.

        Args:
            prompt (str): The seed prompt to evolve.
            method (str): The evolution method key to use.
            return_method (bool): If True, returns method along with prompt.

        Returns:
            Tuple[str, str]: Evolved prompt and the method used (empty string
                when ``return_method`` is False).
        """
        resolved_method = self._resolve_evolution_method(method)

        # Find the strategy whose method group contains the resolved method,
        # so its meta-instruction template can be used.
        strategy_key = None
        for name, group in self.templates.STRATEGY.items():
            if resolved_method in group["methods"]:
                strategy_key = name
                break
        if strategy_key is None:
            strategy_key = random.choice(list(self.templates.STRATEGY.keys()))

        strategy = self.templates.STRATEGY[strategy_key]
        instruction_template = strategy["meta_instruction"]
        instruction = instruction_template.format(
            method=self.templates.EVOL_METHODS.get(
                resolved_method,
                random.choice(list(self.templates.EVOL_METHODS.values())),
            ),
            prompt=prompt,
        )

        # Reset conversation state so each evolution is independent.
        self.agent.reset()
        response = self.agent.step(instruction)
        evolved_prompt = response.msgs[0].content.strip()

        if return_method:
            return (evolved_prompt, resolved_method)
        return (evolved_prompt, "")

    def _generate_multiple_evolutions(
        self,
        prompt: str,
        method: Union[str, List[str]],
        num_generations: int = 2,
        keep_original: bool = True,
        num_threads: int = 10,
    ) -> List[Tuple[str, str]]:
        r"""Generate multiple evolved versions of a prompt.

        Args:
            prompt (str): Seed prompt to evolve.
            method (Union[str, List[str]]): Evolution method specification.
            num_generations (int): Candidates to generate per iteration.
            keep_original (bool): Whether to keep the original prompt.
            num_threads (int): Number of threads for parallel processing.

        Returns:
            List[Tuple[str, str]]: List of (evolved_prompt, method) pairs
        """
        results = [(prompt, "original")] if keep_original else []

        # A pre-resolved plan (list of exactly num_generations method keys)
        # is used verbatim; anything else is resolved here.
        if isinstance(method, list) and len(method) == num_generations:
            candidate_methods = method
        else:
            candidate_methods = self._get_evolution_methods(
                method=method, num_generations=num_generations
            )

        def _process_single_method(method_name: str) -> Tuple[str, str]:
            return self._generate_single_evolution(
                prompt, method_name, return_method=True
            )

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            evolved_results = list(
                executor.map(_process_single_method, candidate_methods)
            )
        results.extend(evolved_results)
        return results

    def _generate_iterative_evolutions(
        self,
        prompt: str,
        evolution_spec: Union[str, List[Union[str, List[str]]]],
        num_generations: int = 2,
        num_iterations: Optional[int] = None,
        keep_original: bool = True,
        scorer: Optional[BaseScorer] = None,
        num_threads: int = 10,
    ) -> Dict[int, List[Dict[str, Any]]]:
        r"""Generate iterative evolutions of a prompt with scoring.

        Args:
            prompt (str): Seed prompt to evolve.
            evolution_spec (Union[str, List[Union[str, List[str]]]]):
                Evolution method specification.
                If a list is provided and num_iterations is None, then
                num_iterations is set to the length of the list.
            num_generations (int): Candidates to generate per iteration.
            num_iterations (Optional[int]): Number of evolution iterations.
                Defaults to the length of evolution_spec.
            keep_original (bool): Include original prompt in results.
            scorer (Optional[BaseScorer]): Scoring model for candidate.
            num_threads (int): Number of threads for parallel processing.

        Returns:
            Dict[int, List[Dict[str, Any]]]: Evolution results per iteration,
                where each candidate is represented as a dict with keys:
                "instruction", "method", and "scores".
        """
        if num_iterations is None:
            num_iterations = (
                len(evolution_spec) if isinstance(evolution_spec, list) else 1
            )

        results: Dict[int, List[Dict[str, Any]]] = {}
        current_prompt = prompt
        scorer = scorer or GeneralScorer()

        for iteration in range(num_iterations):
            # Clamp to the last spec entry when there are more iterations
            # than spec entries.
            if isinstance(evolution_spec, list):
                iteration_spec = evolution_spec[
                    min(iteration, len(evolution_spec) - 1)
                ]
            else:
                iteration_spec = evolution_spec

            batch_results = self._generate_multiple_evolutions(
                prompt=current_prompt,
                method=iteration_spec,
                num_generations=num_generations,
                keep_original=False,
                num_threads=num_threads,
            )

            scored_results = []
            for candidate, method_used in batch_results:
                scores = scorer.score(current_prompt, candidate)
                scored_results.append(
                    {
                        "instruction": candidate,
                        "method": method_used,
                        "scores": scores,
                    }
                )

            # The next iteration evolves the candidate with the highest total
            # score across all scoring dimensions.
            best_index = max(
                range(len(scored_results)),
                key=lambda i: sum(
                    cast(Dict[str, int], scored_results[i]["scores"]).values()
                ),
            )
            best_candidate = cast(
                str, scored_results[best_index]["instruction"]
            )

            if keep_original:
                results[iteration] = [
                    {
                        "instruction": current_prompt,
                        "method": "original",
                        "scores": {},
                    },
                    *scored_results,
                ]
            else:
                results[iteration] = scored_results
            current_prompt = best_candidate

        return results

    def generate(
        self,
        prompts: List[str],
        evolution_spec: Union[str, List[Union[str, List[str]]]],
        num_generations: int = 2,
        num_iterations: Optional[int] = None,
        keep_original: bool = True,
        scorer: Optional[BaseScorer] = None,
        num_chunks: int = 1,
        retry_limit: int = 3,
        retry_delay: float = 1.0,
        num_threads: int = 10,
    ) -> List[Dict[int, List[Dict[str, Any]]]]:
        r"""Evolve a batch of prompts through iterative refinement.

        Args:
            prompts (List[str]): Seed prompts to evolve.
            evolution_spec (Union[str, List[Union[str, List[str]]]]):
                Evolution method specification.
                If a list is provided and num_iterations is None, then
                num_iterations is set to the length of the list.
            num_generations (int): Candidates to generate per iteration.
            num_iterations (Optional[int]): Number of evolution iterations.
                Defaults to the length of evolution_spec.
            keep_original (bool): Include original prompts in results.
            scorer (Optional[BaseScorer]): Scoring model for candidate.
            num_chunks (int): Number of parallel processing chunks.
            retry_limit (int): Max retries for failed generations.
            retry_delay (float): Delay between retries in seconds.
            num_threads (int): Number of threads for parallel processing.

        Returns:
            List[Dict[int, List[Dict[str, Any]]]]: Evolution results.
        """
        # Robustness fix: an empty batch previously crashed below with
        # `range(0, 0, 0)` (zero step).
        if not prompts:
            return []

        if num_iterations is None:
            num_iterations = (
                len(evolution_spec) if isinstance(evolution_spec, list) else 1
            )

        # Pre-resolve a per-prompt evolution plan so method selection
        # (including random strategy sampling) happens once, up front.
        evolution_plan: List[List[List[str]]] = []
        for _ in prompts:
            prompt_plan: List[List[str]] = []
            for iteration in range(num_iterations):
                if isinstance(evolution_spec, list):
                    raw_spec = evolution_spec[
                        min(iteration, len(evolution_spec) - 1)
                    ]
                else:
                    raw_spec = evolution_spec
                prompt_plan.append(
                    self._get_evolution_methods(raw_spec, num_generations)
                )
            evolution_plan.append(prompt_plan)

        def _process_prompt(
            args: Tuple[str, List[List[str]]],
        ) -> Dict[int, List[Dict[str, Any]]]:
            prompt, methods = args
            retries = 0
            while retries <= retry_limit:
                try:
                    # Bug fix: pass the pre-computed per-prompt plan
                    # (`methods`) instead of the raw `evolution_spec`; the
                    # plan was previously built and threaded through but
                    # never used. Each iteration's entry is a list of
                    # exactly num_generations method keys, which
                    # _generate_multiple_evolutions consumes verbatim.
                    return self._generate_iterative_evolutions(
                        prompt=prompt,
                        evolution_spec=methods,
                        num_generations=num_generations,
                        num_iterations=num_iterations,
                        keep_original=keep_original,
                        scorer=scorer,
                        num_threads=num_threads,
                    )
                except Exception as e:
                    retries += 1
                    if retries <= retry_limit:
                        logger.warning(
                            f"Error processing prompt "
                            f"(attempt {retries}/{retry_limit}): {e!s}"
                        )
                        time.sleep(retry_delay)
                    else:
                        logger.error("Failed to process prompt.")
                        return {}
            # Unreachable; present to satisfy static return analysis.
            raise RuntimeError("_process_prompt() did not return.")

        num_chunks = max(1, min(num_chunks, len(prompts)))
        chunk_size = ceil(len(prompts) / num_chunks)
        results = []
        for chunk_idx in range(0, len(prompts), chunk_size):
            chunk = prompts[chunk_idx : chunk_idx + chunk_size]
            plan_chunk = evolution_plan[chunk_idx : chunk_idx + chunk_size]
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                chunk_results = list(
                    tqdm(
                        executor.map(_process_prompt, zip(chunk, plan_chunk)),
                        total=len(chunk),
                    )
                )
            results.extend(chunk_results)
        return results

View File

@@ -0,0 +1,166 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from abc import ABC, abstractmethod
from typing import Dict, Optional
from pydantic import BaseModel, Field
from camel.agents import ChatAgent
from camel.logger import get_logger
logger = get_logger(__name__)
class BaseScorer(ABC):
    @abstractmethod
    def score(
        self, reference_prompt: str, candidate_prompt: str
    ) -> Dict[str, int]:
        r"""Score *candidate_prompt* relative to *reference_prompt*.

        Implementations return a mapping from scoring dimension to an
        integer score, where higher is better — for example diversity,
        difficulty, and feasibility.
        """
        ...
class MathScorer(BaseScorer):
    def __init__(self, agent: Optional[ChatAgent] = None):
        # System prompt defining the four scoring dimensions and the JSON
        # reply format expected from the evaluator agent.
        self.system_msg = (
            "You are an evaluator for math problems. Your task is to compare "
            "a new math problem against a reference math problem, and rate it "
            "in **four dimensions**, each scored from 1 to 5.\n\n"
            "1. Diversity (1-5): How novel is the new problem compared to the "
            "reference? 1 = very similar, 5 = completely different.\n"
            "2. Difficulty (1-5): Rate the relative difficulty compared to the"
            " reference problem. 1 = much less difficult, "
            "3 = similar difficulty, 5 = much more difficult.\n"
            "3. Validity (1-5): How well-defined and sound is the problem?"
            "1 = very vague or flawed, 5 = very clear and rigorous.\n"
            "4. Solvability (1-5): How likely is the problem solvable using "
            "standard math techniques? 1 = very unsolvable or ambiguous, "
            "5 = very clearly solvable.\n\n"
            "Respond with a JSON object like: "
            "{ \"diversity\": ..., \"difficulty\": ..., "
            "\"validity\": ..., \"solvability\": ... }"
        )
        self.agent = agent or ChatAgent(self.system_msg)

    class MathScoreSchema(BaseModel):
        # Structured response schema the agent is asked to fill in.
        diversity: int = Field(
            ...,
            description=(
                "Score for the diversity of the math problem "
                "compared to the reference"
            ),
        )
        difficulty: int = Field(
            ..., description="Score for the relative difficulty"
        )
        validity: int = Field(
            ...,
            description="Score for how well-defined and sound the problem is",
        )
        solvability: int = Field(
            ...,
            description="Score for the solvability of the problem",
        )

    def score(
        self, reference_problem: str, new_problem: str
    ) -> Dict[str, int]:
        r"""Rate *new_problem* relative to *reference_problem*.

        Args:
            reference_problem (str): The reference math problem.
            new_problem (str): The new or evolved math problem.

        Returns:
            Dict[str, int]: Scores for diversity, difficulty, validity, and
                solvability.
        """
        prompt_lines = [
            "Reference problem:",
            reference_problem,
            "",
            "New problem:",
            new_problem,
            "",
            "Provide scores in JSON format.",
        ]
        reply = self.agent.step(
            "\n".join(prompt_lines), response_format=self.MathScoreSchema
        )
        return json.loads(reply.msg.content)
class GeneralScorer(BaseScorer):
    def __init__(self, agent: Optional[ChatAgent] = None):
        # System prompt defining the three scoring dimensions and the JSON
        # reply format expected from the evaluator agent.
        self.system_msg = (
            "You are an evaluator for problems in various domains. Your task "
            "is to compare a new problem against a reference problem, and rate"
            " it in **three dimensions**, each scored from 1 to 5.\n\n"
            "1. Diversity (1-5): How novel is the new problem compared to the "
            "reference? 1 = very similar, 5 = completely different.\n"
            "2. Complexity (1-5): Relative to the reference problem. "
            "1 = much less complex, 3 = similar complexity, "
            "5 = much more complex.\n"
            "3. Validity (1-5): How well-defined, meaningful, the problem is."
            "1 = vague/flawed, 5 = precise and fully meaningful.\n"
            "Respond with a JSON object like: "
            "{ \"diversity\": ..., \"complexity\": ..., \"validity\": ... }"
        )
        self.agent = agent or ChatAgent(self.system_msg)

    class GeneralScoreSchema(BaseModel):
        # Structured response schema the agent is asked to fill in.
        diversity: int = Field(
            ...,
            description=(
                "Score for the diversity of the problem "
                "compared to the reference."
            ),
        )
        complexity: int = Field(
            ...,
            description=("Score for the relative complexity of the problem."),
        )
        validity: int = Field(
            ...,
            description=(
                "Score estimating the likelihood that the problem is "
                "well-defined."
            ),
        )

    def score(
        self, reference_problem: str, new_problem: str
    ) -> Dict[str, int]:
        r"""Rate *new_problem* relative to *reference_problem* using
        structured scoring.

        Args:
            reference_problem (str): The original problem.
            new_problem (str): The evolved or new problem.

        Returns:
            Dict[str, int]: Scores for diversity, complexity, and validity.
        """
        prompt_lines = [
            "Reference problem:",
            reference_problem,
            "",
            "New problem:",
            new_problem,
            "",
            "Provide scores in JSON format.",
        ]
        reply = self.agent.step(
            "\n".join(prompt_lines), response_format=self.GeneralScoreSchema
        )
        return json.loads(reply.msg.content)

View File

@@ -0,0 +1,268 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Union
# flake8: noqa
@dataclass(frozen=True)
class BaseEvolInstructTemplates(ABC):
    r"""Abstract contract for evolution instruction templates.

    Concrete subclasses supply two pieces of data:

    - ``EVOL_METHODS``: a mapping from method keys to their descriptions.
    - ``STRATEGY``: a mapping from strategy names to their configuration
      (meta-instruction plus the associated method keys).

    Subclasses should define concrete templates for specific domains.
    """

    @property
    @abstractmethod
    def EVOL_METHODS(self) -> Dict[str, str]:
        r"""Mapping of evolution method keys to their descriptions."""
        ...

    @property
    @abstractmethod
    def STRATEGY(self) -> Dict[str, Dict[str, Union[str, List[str]]]]:
        r"""Mapping of strategy names to their corresponding methods."""
        ...
# flake8: noqa
@dataclass(frozen=True)
class EvolInstructTemplates(BaseEvolInstructTemplates):
    r"""Contains templates for EvolInstruct prompt transformations.

    References:
        - WizardLM: Empowering Large Language Models to Follow Complex
          Instructions
          https://arxiv.org/pdf/2304.12244
        - eva: Evolving Alignment via Asymmetric Self-Play
          https://arxiv.org/abs/2411.00062
    """

    # High-level instructions on in-depth/in-breadth evolving.
    # Both templates expose two format slots: {method} (filled with an entry
    # from EVOL_METHODS) and {prompt} (the seed prompt to transform).
    INST_IN_DEPTH = (
        "Please act as an expert Prompt Creator.\n"
        "Your objective is to rewrite a given prompt into a more complex "
        "version to make those large language models (e.g., gemini) a bit "
        "harder to handle.\n"
        "But the rewritten prompt must be reasonable and must be understood "
        "and responded by humans.\n"
        "Your rewriting cannot omit the non-text parts such as the table and "
        "code in #Given Prompt#, if there is any."
        "You should try your best not to make the #Rewritten Prompt# become "
        "verbose, "
        "The #Rewritten Prompt# should be roughly the similar length or a "
        "little bit more than that of #Given Prompt#.\n"
        "The #Rewritten Prompt# must sound like a real human user's prompt; "
        "DON'T make it like sound machine-generated."
        "Specifically, you SHOULD complicate the given prompt using the "
        "following method: "
        "\n{method}\n"
        "The rewritten prompt should reflect meaningful changes across its "
        "structure, ensuring the entire sentence feels sufficiently different "
        "from the original. "
        "Again, make sure the rewritten prompt is more CHALLENGING."
        "Respond with your rewritten prompt directly. "
        "#Given Prompt#:\n{prompt}\n"
        "#Rewritten Prompt#:\n"
    ).lstrip()

    INST_IN_BREADTH = (
        "Please act as an expert Prompt Creator.\n"
        "Your objective is to generate a brand-new prompt based on the #Given "
        "Prompt#. "
        "The purpose of this task is to promote diversity and generality of "
        "training prompts for language models, helping it practice with "
        "varied challenges and perspectives.\n"
        "The LENGTH and complexity of the #Created Prompt# should be similar "
        "to that of the #Given Prompt#.\n"
        "The #Created Prompt# must be reasonable, interpretable, and solvable "
        "by humans.\n"
        "The #Created Prompt# must sound like a real human user's prompt; "
        "DON'T make it sound like machine-generated."
        "Follow the method described below to guide your creation:\n"
        "{method}\n"
        "The created prompt should reflect meaningful changes across its "
        "structure, ensuring the entire sentence feels sufficiently different "
        "from the original. "
        "Respond with your created prompt directly.\n"
        "#Given Prompt#:\n{prompt}\n"
        "#Created Prompt#:\n"
    ).lstrip()

    # Sub-method instructions (following the eva paper setting).
    # Each key below indexes a description in EVOL_METHODS.
    IN_BREADTH_KEYS = [
        'persona',
        'shift-in',
        'shift-out',
        'mix',
        'abstract',
    ]
    IN_DEPTH_KEYS = [
        'constraints',
        'deepening',
        'concretizing',
        'reasoning',
        'expansion',
    ]

    # Strategy name -> meta-instruction template plus the method keys it
    # covers; consumed by EvolInstructPipeline.
    STRATEGY = {
        "IN-DEPTH": {
            'meta_instruction': INST_IN_DEPTH,
            'methods': IN_DEPTH_KEYS,
        },
        "IN-BREADTH": {
            'meta_instruction': INST_IN_BREADTH,
            'methods': IN_BREADTH_KEYS,
        },
    }

    # Method key -> instruction text substituted into the {method} slot of
    # the meta-instruction templates above.
    EVOL_METHODS = {
        "persona": (
            "Reframe the #Given Prompt# as if written by a user with a "
            "completely different persona, background, or expertise. Adjust "
            "the tone, style, phrasing, or anything you feel proper to "
            "reflect this change. The changes should make the prompt feel "
            "like it was authored by someone entirely new."
        ),
        "shift-in": (
            "Shift the high-level idea of the #Given Prompt# to explore a "
            "different subdomain or context within the same domain. Ensure "
            "the new topic still challenges the model to reason or provide "
            "knowledge relevant to the domain."
        ),
        "shift-out": (
            "Shift the high-level idea of the #Given Prompt# to a completely "
            "different topic in a different setting. The new topic may "
            "challenge the model with similar reasoning or contextual "
            "understanding but in a novel way."
        ),
        "mix": (
            "Combine the high-level concept of the #Given Prompt# with "
            "elements from a different domain. Introduce novel scenarios or "
            "contexts to create diversity while maintaining relevance to the "
            "original idea."
        ),
        "abstract": (
            "Turn the #Given Prompt# into a more abstract or generalized "
            "version, removing specific details while preserving its intent. "
            "Ensure the new prompt encourages broader, principle-driven "
            "reasoning."
        ),
        "constraints": (
            "Add one or more significant constraints or requirements into the "
            "'#Given Prompt#'. The added constraints must meaningfully alter "
            "how the model would respond. For example, specify additional "
            "rules, contexts, or limitations that demand creative adjustments."
        ),
        "deepening": (
            "If the #Given Prompt# contains inquiries about certain issues, "
            "increase the depth and breadth of the inquiry. Make the question "
            "require a more detailed, multi-layered, or comprehensive response"
            ". For instance, break the problem into sub-problems or require "
            "connections between unrelated concepts."
        ),
        "concretizing": (
            "Replace general concepts in the #Given Prompt# with more specific"
            " and detailed concepts. Ensure that the change makes the problem "
            "more defined and concrete, leaving less room for ambiguity. For "
            "example, replace 'a device' with 'a wearable fitness tracker "
            "with GPS'."
        ),
        "reasoning": (
            "Add one or more reasoning steps into the '#Given Prompt#'. "
            "Explicitly rewrite it to demand multi-step reasoning or justify "
            "intermediate steps in the solution. For instance, if the original"
            " prompt is a simple query, make the response require a "
            "step-by-step breakdown of logic or calculations."
        ),
        "expansion": (
            "Expand the #Given Prompt# by including additional perspectives, "
            "domains, or layers of complexity. For example, if the original "
            "prompt focuses on a single scenario, add related scenarios or ask"
            " the model to compare different situations."
        ),
    }
# flake8: noqa
@dataclass(frozen=True)
class MathEvolInstructTemplates(BaseEvolInstructTemplates):
    r"""Contains templates for MathEvolInstruct prompt transformations."""

    # Meta-instructions for in-depth evolving; exposes {prompt} as the only
    # seed slot (the chosen method text is prepended by the pipeline via
    # str.format on EVOL_METHODS — note there is no explicit {method} slot
    # here; NOTE(review): confirm the pipeline's .format(method=..., ...)
    # silently drops the unused `method` argument as intended.
    INST_IN_DEPTH = (
        "Please act as a math expert. Your objective is to create a new math "
        "problem that is more challenging yet concise than the given math "
        "problem. Ensure that the mathematical content (including any "
        "equations or figures) is preserved, and rephrase the problem to "
        "increase its complexity and depth. The generated problem should be "
        "clearly stated, strictly mathematical, and suitable for solving with "
        "symbolic computation (e.g., using sympy). You will be given a method "
        "to guide your creation. Make sure to follow the method strictly. "
        "Consolidate any multiple parts into one integrated question that "
        "ask for one definitive answer. Respond with your generated problem "
        "directly. "
        "#Original Problem#:\n{prompt}\n"
        "#Generated Problem#:\n"
    ).lstrip()

    # Method key -> instruction text. NOTE(review): 'condense' is defined
    # here but not listed in IN_DEPTH_KEYS below, so the IN-DEPTH strategy
    # never samples it — confirm whether that is intentional.
    EVOL_METHODS = {
        "constraints": (
            "Add one or more significant constraints or requirements into the "
            "'#Given Prompt#'. The added constraints must meaningfully alter "
            "how the model would respond. For example, specify additional "
            "rules, contexts, or limitations that demand creative adjustments."
        ),
        "deepening": (
            "Increase the difficulty of the #Given Prompt# by integrating "
            "additional layers of reasoning and rigor. Refine the problem so "
            "that all added difficulty is consolidated into a single coherent "
            "question requiring one final answer, avoiding fragmentation into "
            "multiple sub-problems."
        ),
        "expansion": (
            "Expand the #Given Prompt# by incorporating additional "
            "perspectives or layers of complexity into the problem statement. "
            "Ensure that the revised problem remains a single, unified "
            "question with one final answer, rather than a series of separate "
            "sub-questions."
        ),
        "condense": (
            "Reformulate the given math problem into a well-structured and "
            "formally stated mathematical question.\n"
            "- Present the problem in a structured and rigorous mathematical "
            "format.\n"
            "- Removing unnecessary instructions, explanations, or hints.\n"
            "- If the given problem contains several sub-questions, make "
            "necessary changes to let the problem could be answered with one "
            "number or expression by removing the sub-questions or combining "
            "them into one."
        ),
    }

    IN_DEPTH_KEYS = ['constraints', 'deepening', 'expansion']

    # Only one strategy is defined for the math domain.
    STRATEGY = {
        "IN-DEPTH": {
            'meta_instruction': INST_IN_DEPTH,
            'methods': IN_DEPTH_KEYS,
        },
    }

View File

@@ -0,0 +1,899 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
import json
import math
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.models.reward import BaseRewardModel, Evaluator
from camel.utils import BatchProcessor, retry_on_error
logger = get_logger(__name__)
class AgentTraceEvaluation(BaseModel):
    r"""Scores produced when a ChatAgent self-evaluates a reasoning trace."""

    # Quality dimensions; the numeric scale is set by the evaluation prompt
    # elsewhere in the pipeline, not enforced by this model.
    correctness: float
    clarity: float
    completeness: float
    # Free-form feedback used to guide the next improvement iteration.
    feedback: str
class RewardTraceEvaluation(BaseModel):
    r"""Evaluation of a reasoning trace produced by a reward model.

    Only ``feedback`` is declared; reward models report varying score
    dimensions, so extra fields are accepted dynamically via the ``extra``
    config below. (A previous no-op ``__init__`` override that merely
    forwarded to ``super().__init__`` has been removed — pydantic already
    provides that behavior.)
    """

    feedback: str

    class Config:
        extra = (
            "allow"  # Allow extra fields for different reward model dimensions
        )
class TraceIteration(BaseModel):
    r"""One improvement round: the generated trace and its evaluation."""

    # Index of the refinement round — presumably starts at 1; verify
    # against the pipeline code that populates it.
    iteration: int
    # The reasoning trace produced in this round.
    trace: str
    # Agent self-evaluation or reward-model evaluation of the trace.
    evaluation: Union[AgentTraceEvaluation, RewardTraceEvaluation]
class ProblemResult(BaseModel):
    r"""Final outcome for a single problem processed by the pipeline."""

    # Optional metadata carried over from the input problem dict.
    id: Optional[str] = None
    type: Optional[str] = None
    problem: str
    # Reference solution, if one was supplied with the problem.
    solution: Optional[str] = None
    # The last (best) reasoning trace after all improvement iterations.
    final_trace: str
    # NOTE(review): presumably None when a reward model is used instead of
    # agent self-evaluation — confirm against the pipeline code.
    agent_evaluate_success: Optional[bool] = None
    # Whether the final trace contained an extractable boxed answer.
    boxed_answer_success: bool = False
    # Per-iteration traces and evaluations, in order.
    improvement_history: List[TraceIteration]
class SelfImprovingCoTPipeline:
r"""Pipeline for generating self-taught reasoning traces
using the self-improving methodology.
This implements the STaR paper's approach of:
1. Initial reasoning trace generation
2. Self-evaluation
3. Feedback-based improvement
4. Iterative refinement
"""
def __init__(
self,
reason_agent: ChatAgent,
problems: List[Dict],
max_iterations: int = 3,
score_threshold: Union[float, Dict[str, float]] = 0.7,
rejection_sampling_n: Optional[int] = None,
evaluate_agent: Optional[ChatAgent] = None,
reward_model: Optional[BaseRewardModel] = None,
output_path: Optional[str] = None,
few_shot_examples: Optional[str] = None,
batch_size: Optional[int] = None,
max_workers: Optional[int] = None,
solution_pattern: str = r'\\boxed{(.*?)}',
trace_pattern: Optional[str] = None,
):
r"""Initialize the self-improving cot pipeline.
Args:
reason_agent (ChatAgent): The chat agent used for generating and
improving reasoning traces.
problems (List[Dict]): List of problem dictionaries to process.
max_iterations (int, optional): Maximum number of improvement
iterations. If set to `0`, the pipeline will generate an
initial trace without any improvement iterations.
(default: :obj:`3`)
score_threshold (Union[float, Dict[str, float]], optional):
Quality threshold. Can be either a single float value applied
to average score, or a dictionary mapping score dimensions to
their thresholds. For example: {"correctness": 0.8,
"coherence": 0.7}. If using reward model and threshold for a
dimension is not specified, will use the default value 0.7.
(default: :obj:`0.7`)
rejection_sampling_n (int, optional): Specifies the number of
samples to be drawn using the rejection sampling
method, where samples are accepted or rejected based on
a predefined condition to achieve a desired distribution.
(default: :obj: `None`)
evaluate_agent (Optional[ChatAgent]): The chat agent used for
evaluating reasoning traces. (default: :obj:`None`)
reward_model (BaseRewardModel, optional): Model used to evaluate
reasoning traces. If `None`, uses Agent self-evaluation.
(default: :obj:`None`)
output_path (str, optional): Output path for saving traces. If
`None`, results will only be returned without saving to file.
(default: :obj:`None`)
few_shot_examples (str, optional): Examples to use for few-shot
generation. (default: :obj:`None`)
batch_size (int, optional): Batch size for parallel processing.
(default: :obj:`None`)
max_workers (int, optional): Maximum number of worker threads.
(default: :obj:`None`)
solution_pattern (str, optional): Regular expression pattern with
one capture group to extract answers from solution text.
(default: :obj:`r'\\boxed{(.*?)}'`)
trace_pattern (str, optional): Regular expression pattern with one
capture group to extract answers from trace text. If `None`,
uses the same pattern as solution_pattern.
(default: :obj:`None`)
"""
self.reason_agent = reason_agent
self.evaluate_agent = evaluate_agent
self.problems = problems
self.output_path = output_path
self.max_iterations = max_iterations
self.score_threshold = score_threshold
self.rejection_sampling_n = rejection_sampling_n
self.reward_model = reward_model
self.evaluator = (
Evaluator(reward_model=reward_model) if reward_model else None
)
self.reasoning_traces: List[Dict[str, Any]] = []
self.few_shot_examples = few_shot_examples
self.batch_processor = BatchProcessor(max_workers, batch_size)
self.solution_pattern = solution_pattern
self.trace_pattern = (
trace_pattern if trace_pattern is not None else solution_pattern
)
# Initialize output file with empty results if path is specified
if self.output_path:
with open(self.output_path, 'w') as f:
json.dump({'traces': []}, f, indent=2, ensure_ascii=False)
self.lock = threading.Lock()
def safe_write_json(self, file_path, data):
temp_path = file_path + ".tmp"
with open(temp_path, "w") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(temp_path, file_path)
def clean_json(self, data):
if isinstance(data, dict):
return {k: self.clean_json(v) for k, v in data.items()}
elif isinstance(data, list):
return [self.clean_json(v) for v in data]
elif isinstance(data, float) and (
math.isnan(data) or math.isinf(data)
):
return None
return data
    async def _batch_process_problems(
        self, problems: List[Dict], rationalization: bool
    ) -> List[ProblemResult]:
        r"""Process multiple problems in parallel batches with dynamic sizing.

        Each batch is dispatched to a thread pool; the batch size and worker
        count are re-tuned after every batch via ``self.batch_processor``.
        Problems whose processing raises are logged and dropped, so the
        returned list may be shorter than ``problems``.

        Args:
            problems (List[Dict]): List of problem dictionaries to process.
            rationalization (bool): Whether to use rationalization.

        Returns:
            List[ProblemResult]: List of problem results.
        """
        results = []
        total_problems = len(problems)
        processed = 0
        while processed < total_problems:
            # Batch size may change between iterations (dynamic sizing).
            batch_size = self.batch_processor.batch_size
            batch = problems[processed : processed + batch_size]
            batch_start_time = time.time()
            try:
                with ThreadPoolExecutor(
                    max_workers=self.batch_processor.max_workers
                ) as executor:
                    # Create futures with rationalization parameter
                    futures = [
                        executor.submit(
                            self.process_problem,
                            problem=problem,
                            rationalization=rationalization,
                        )
                        for problem in batch
                    ]
                    batch_results = []
                    batch_success = True
                    for future in as_completed(futures):
                        try:
                            result = future.result()
                            batch_results.append(result)
                        except Exception as e:
                            # Failed problems are dropped, not retried.
                            logger.error(f"Error processing problem: {e}")
                            batch_success = False
                            continue
                    results.extend(batch_results)
                    # Advance past the whole batch even if some items failed.
                    processed += len(batch)
                    # Calculate processing time and adjust batch size
                    processing_time = time.time() - batch_start_time
                    self.batch_processor.adjust_batch_size(
                        batch_success, processing_time
                    )
                    # Log progress and performance metrics
                    metrics = self.batch_processor.get_performance_metrics()
                    logger.info(
                        f"Processed {processed}/{total_problems} problems "
                        f"(batch size: {batch_size}, workers: "
                        f"{metrics['current_workers']}, "
                        f"CPU: {metrics['current_cpu']:.1f}%, "
                        f"Memory: {metrics['current_memory']:.1f}%)"
                    )
            except Exception as e:
                # NOTE(review): `processed` is not advanced here, so a
                # persistent batch-level failure (e.g. executor setup error)
                # retries the same batch indefinitely — confirm this retry
                # behavior is intended.
                logger.error(f"Batch processing error: {e}")
                self.batch_processor.adjust_batch_size(False)
                continue
        return results
async def _batch_evaluate_traces(
self,
problems: List[Dict[str, Any]],
traces: List[str],
solutions: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
r"""Evaluate multiple traces in parallel batches with resource
monitoring.
Args:
problems (List[Dict[str, Any]]): List of problem dictionaries
traces (List[str]): List of reasoning traces to evaluate
solutions (Optional[List[str]]): Optional list of solutions
Returns:
List[Dict[str, Any]]: List of evaluation results
"""
if solutions is None:
solutions = ["null"] * len(problems)
results = []
total_traces = len(traces)
processed = 0
while processed < total_traces:
batch_size = self.batch_processor.batch_size
problem_batch = problems[processed : processed + batch_size]
trace_batch = traces[processed : processed + batch_size]
solution_batch = solutions[processed : processed + batch_size]
batch_start_time = time.time()
try:
with ThreadPoolExecutor(
max_workers=self.batch_processor.max_workers
) as executor:
futures = [
executor.submit(
self.evaluate_trace,
problem=problem["problem"],
trace=trace,
solution=solution,
)
for problem, trace, solution in zip(
problem_batch, trace_batch, solution_batch
)
]
batch_results = []
batch_success = True
for future in as_completed(futures):
try:
result = future.result()
batch_results.append(result)
except Exception as e:
logger.error(f"Error evaluating trace: {e}")
batch_success = False
continue
results.extend(batch_results)
processed += len(batch_results)
# Calculate processing time and adjust batch size
processing_time = time.time() - batch_start_time
self.batch_processor.adjust_batch_size(
batch_success, processing_time
)
# Log progress and performance metrics
metrics = self.batch_processor.get_performance_metrics()
logger.info(
f"Evaluated {processed}/{total_traces} traces "
f"(batch size: {batch_size}, workers: "
f"{metrics['current_workers']}, "
f"avg time: {metrics['avg_processing_time']:.2f}s, "
f"error rate: {metrics['error_rate']:.1f}%)"
)
except Exception as e:
logger.error(f"Batch evaluation error: {e}")
self.batch_processor.adjust_batch_size(False)
continue
return results
def _check_score_threshold(self, scores: Dict[str, float]) -> bool:
r"""Check if scores meet the threshold requirements.
Args:
scores (Dict[str, float]): Dictionary of scores for different
dimensions.
Returns:
bool: True if scores meet threshold requirements, False otherwise.
"""
# If score_threshold is a float, apply it to all dimensions
if isinstance(self.score_threshold, float):
return all(
score >= self.score_threshold for score in scores.values()
)
# If score_threshold is a dict, check each dimension with its threshold
# Use 0 as default threshold for unspecified dimensions
if isinstance(self.score_threshold, dict):
for dim, score in scores.items():
threshold = self.score_threshold.get(dim, 0)
if score < threshold:
return False
return True
# If score_threshold is None or invalid type, pass the check
return True
def _generate_feedback(self, scores: Dict[str, float]) -> str:
r"""Generate feedback based on which dimensions need improvement.
Args:
scores (Dict[str, float]): Dictionary of scores for different
dimensions.
Returns:
str: Feedback message indicating which dimensions need improvement.
"""
if isinstance(self.score_threshold, float):
below_threshold = [
dim
for dim, score in scores.items()
if score < self.score_threshold
]
if not below_threshold:
return "All dimensions meet the required threshold"
dims = ", ".join(below_threshold)
return f"Need improvement in: {dims}"
if isinstance(self.score_threshold, dict):
default_threshold = 0
below_threshold = [
dim
for dim, score in scores.items()
if score < self.score_threshold.get(dim, default_threshold)
]
if not below_threshold:
return "All dimensions meet their respective thresholds"
dims = ", ".join(below_threshold)
return f"Need improvement in: {dims}"
# If no threshold set, just list all dimensions and their scores
dims = ", ".join(
f"{dim}: {score:.2f}" for dim, score in scores.items()
)
return f"Current scores - {dims}"
@retry_on_error()
def generate_reasoning_trace(self, problem: str) -> str:
r"""Generate initial reasoning trace for a given problem.
Args:
problem (str): The problem text to generate reasoning for.
Returns:
str: Generated reasoning trace.
"""
self.reason_agent.reset()
few_shot_examples = (
f"Examples: {self.few_shot_examples}"
if self.few_shot_examples
else ""
)
prompt = self.REASONING_TEMPLATE.format(
problem=problem, few_shot_examples=few_shot_examples
)
response = self.reason_agent.step(prompt)
return response.msg.content
@retry_on_error()
def evaluate_trace(
self, problem: str, trace: str, solution: Optional[str] = None
) -> Dict[str, Any]:
r"""Evaluate the quality of a reasoning trace.
Args:
problem (str): The original problem text to evaluate against.
trace (str): The reasoning trace to evaluate.
solution (Optional[str]): The solution to the problem, if provided.
(default: :obj:`None`)
Returns:
Dict[str, Any]: Evaluation results containing:
- scores: Dict of evaluation dimensions and their scores
- feedback: Detailed feedback for improvement
For Agent self-evaluation, the scores will include:
- correctness: Score for logical correctness
- clarity: Score for clarity of explanation
- completeness: Score for completeness of reasoning
For reward model evaluation, the scores will depend on
the model's evaluation dimensions.
"""
self.evaluate_agent.reset() # type: ignore[union-attr]
if self.evaluator:
# Use reward model evaluation
messages = [
{"role": "user", "content": problem},
{"role": "assistant", "content": trace},
]
scores = self.evaluator.evaluate(messages)
# For models that return a single score
if isinstance(scores, (int, float)) or (
isinstance(scores, dict) and len(scores) == 1
):
if isinstance(scores, dict):
score = next(iter(scores.values()))
else:
score = scores
scores_dict = {"overall": score}
return {
**scores_dict,
"feedback": self._generate_feedback(scores_dict),
}
# For models that return multiple dimensions
return {**scores, "feedback": self._generate_feedback(scores)}
else:
# Fallback to original Agent self-evaluation
solution_text = f"Solution: {solution}" if solution else ""
prompt = self.EVALUATION_TEMPLATE.format(
problem=problem, trace=trace, solution=solution_text
)
response = self.evaluate_agent.step( # type: ignore[union-attr]
prompt, response_format=AgentTraceEvaluation
)
if response.msg.parsed is None:
raise AttributeError("Failed to parse evaluation response")
# Convert dict to AgentTraceEvaluation if needed
if isinstance(response.msg.parsed, dict):
evaluation = AgentTraceEvaluation(**response.msg.parsed)
else:
evaluation = response.msg.parsed
return evaluation.model_dump()
    @retry_on_error()
    def generate_reasoning_trace_rejection(self, problem: str) -> str:
        r"""Generate multiple candidate reasoning traces for a problem and
        select the best one based on evaluation.

        Args:
            problem (str): The problem text for generating a reasoning trace.

        Returns:
            str: The best candidate trace that meets quality criteria, or the
                highest-scoring candidate if none meet the threshold.
        """
        few_shot_examples = (
            f"Examples: {self.few_shot_examples}"
            if self.few_shot_examples
            else ""
        )
        prompt = self.REASONING_TEMPLATE.format(
            problem=problem, few_shot_examples=few_shot_examples
        )
        responses, candidate_traces = None, []
        if 'n' in self.reason_agent.model_backend.model_config_dict:
            # NOTE(review): this mutates the shared model config in place and
            # never restores the previous 'n' — confirm that is acceptable.
            self.reason_agent.model_backend.model_config_dict['n'] = (
                self.rejection_sampling_n
            )
            # Generate multiple candidate traces in one call using parameter n
            responses = self.reason_agent.step(prompt)
            # Extract candidate traces
            candidate_traces = [choice.content for choice in responses.msgs]
        else:
            # Backend does not support 'n': fall back to sequential sampling.
            sampling_n = (
                self.rejection_sampling_n
                if self.rejection_sampling_n is not None
                else 1
            )
            for _i in range(sampling_n):
                trace = self.generate_reasoning_trace(problem)
                candidate_traces.append(trace)
        best_trace = None
        # NOTE(review): starting at 0.01 means a threshold-passing candidate
        # whose average score is <= 0.01 is never selected here — presumably
        # intentional to reject all-zero scores; confirm.
        best_avg_score = 0.01
        candidate_avg_scores = []
        for trace in candidate_traces:
            eval_results = self.evaluate_trace(problem, trace)
            # Remove feedback from scores
            scores = {k: v for k, v in eval_results.items() if k != "feedback"}
            # Compute average score (assuming at least one score exists)
            if scores:
                avg_score = sum(scores.values()) / len(scores)
            else:
                avg_score = 0.0
            candidate_avg_scores.append(avg_score)
            # If the candidate meets the threshold and is the best, select it
            if (
                self._check_score_threshold(scores)
                and avg_score > best_avg_score
            ):
                best_trace = trace
                best_avg_score = avg_score
        if best_trace is None:
            # No candidate passed the threshold: keep the highest-scoring one.
            best_trace = candidate_traces[
                candidate_avg_scores.index(max(candidate_avg_scores))
            ]
        return best_trace
@retry_on_error()
def improve_trace(
self,
problem: str,
trace: str,
feedback: str,
solution: Optional[str] = None,
) -> str:
r"""Generate improved reasoning trace based on feedback.
Args:
problem (str): The original problem text.
trace (str): The current reasoning trace.
feedback (str): Feedback for improving the trace.
solution (Optional[str]): The solution to the problem, if provided.
(default: :obj:`None`)
Returns:
str: Improved reasoning trace.
"""
self.reason_agent.reset()
solution_text = f"Solution: {solution}" if solution else ""
prompt = self.IMPROVEMENT_TEMPLATE.format(
problem=problem,
trace=trace,
feedback=feedback,
solution=solution_text,
)
response = self.reason_agent.step(prompt)
return response.msg.content
def validate_problem_format(self, problem: Dict) -> None:
r"""Validate that a problem dictionary has the required format.
Args:
problem (Dict): Problem dictionary to validate.
Raises:
ValueError: If the problem format is invalid.
"""
if not isinstance(problem, dict):
raise ValueError("Problem must be a dictionary.")
# Check required problem field
if "problem" not in problem:
raise ValueError("Problem dictionary must contain 'problem' key.")
if not isinstance(problem["problem"], str):
raise ValueError("Problem 'problem' field must be a string.")
# Optional fields validation
optional_fields: dict[str, type | tuple[type, ...]] = {
"id": (str, int, type(None)),
"type": str,
"solution": str,
}
for field, expected_type in optional_fields.items():
if field in problem and not isinstance(
problem[field], expected_type
):
type_name = (
expected_type.__name__
if hasattr(expected_type, '__name__')
else str(expected_type)
)
raise ValueError(
f"Problem '{field}' must be of "
f"type {type_name} if present."
)
def _check_boxed_answers(self, solution: str, trace: str) -> bool:
r"""Check if the answer in the trace matches the solution using the
configured patterns.
Args:
solution (str): The problem solution string.
trace (str): The reasoning trace string.
Returns:
bool: True if answers match, False otherwise
"""
import re
# Extract content using the configured patterns
solution_match = re.search(self.solution_pattern, solution, re.DOTALL)
trace_match = re.search(self.trace_pattern, trace, re.DOTALL)
if solution_match and trace_match:
# Clean up whitespace and normalize content
solution_answer = solution_match.group(1).strip()
trace_answer = trace_match.group(1).strip()
return solution_answer == trace_answer
return False
    def process_problem(
        self, problem: Dict, rationalization: bool = False
    ) -> ProblemResult:
        r"""Process a single problem through the self-improving cot pipeline.

        Generates an initial trace (with rejection sampling when configured),
        optionally evaluates and iteratively improves it, then records the
        result (and appends it to the output file when configured).

        Args:
            problem (Dict): Problem dictionary containing the problem text.
            rationalization (bool, optional): Whether to use rationalization
                (pass the solution to the improvement prompt).
                (default: :obj:`False`)

        Returns:
            ProblemResult: Results with final trace and history.

        Raises:
            ValueError: If the problem format is invalid.
        """
        # Validate problem format before processing
        self.validate_problem_format(problem)
        problem_text = problem["problem"]
        solution_text = problem.get("solution", "")
        current_trace = None
        if self.rejection_sampling_n:
            current_trace = self.generate_reasoning_trace_rejection(
                problem_text
            )
        else:
            current_trace = self.generate_reasoning_trace(problem_text)
        improvement_history = []
        scores = {}
        # Only evaluate if evaluate_agent or reward_model is set
        if self.evaluate_agent or self.reward_model:
            # Create batches for parallel evaluation (single-element batch)
            batch_problems = [problem]
            batch_traces = [current_trace]
            batch_solutions = [solution_text]
            # Evaluate current trace batch. A fresh event loop is created
            # because this method runs inside worker threads, which have no
            # running loop of their own.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                eval_results = loop.run_until_complete(
                    self._batch_evaluate_traces(
                        batch_problems, batch_traces, batch_solutions
                    )
                )
            finally:
                loop.close()
            # Process evaluation results
            eval_dict = eval_results[-1]  # Get latest evaluation
            scores = {k: v for k, v in eval_dict.items() if k != "feedback"}
            # Record initial evaluation (iteration 0)
            if self.evaluator:
                improvement_history.append(
                    TraceIteration(
                        iteration=0,
                        trace=current_trace,
                        evaluation=RewardTraceEvaluation(**eval_dict),
                    )
                )
            else:
                improvement_history.append(
                    TraceIteration(
                        iteration=0,
                        trace=current_trace,
                        evaluation=AgentTraceEvaluation(
                            **scores, feedback=eval_dict["feedback"]
                        ),
                    )
                )
            # Only do improvement iterations if max_iterations > 0
            if self.max_iterations > 0:
                for iteration in range(0, self.max_iterations):
                    # Check if quality threshold met
                    if self._check_score_threshold(scores):
                        break
                    # Generate improved trace; rationalization also feeds the
                    # reference solution into the improvement prompt.
                    if rationalization:
                        current_trace = self.improve_trace(
                            problem_text,
                            current_trace,
                            eval_dict["feedback"],
                            solution_text,
                        )
                    else:
                        current_trace = self.improve_trace(
                            problem_text, current_trace, eval_dict["feedback"]
                        )
                    # Evaluate improved trace
                    batch_traces = [current_trace]
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        eval_results = loop.run_until_complete(
                            self._batch_evaluate_traces(
                                batch_problems, batch_traces, batch_solutions
                            )
                        )
                    finally:
                        loop.close()
                    eval_dict = eval_results[-1]
                    scores = {
                        k: v for k, v in eval_dict.items() if k != "feedback"
                    }
                    # Record iteration history
                    if self.evaluator:
                        improvement_history.append(
                            TraceIteration(
                                iteration=iteration + 1,
                                trace=current_trace,
                                evaluation=RewardTraceEvaluation(**eval_dict),
                            )
                        )
                    else:
                        improvement_history.append(
                            TraceIteration(
                                iteration=iteration + 1,
                                trace=current_trace,
                                evaluation=AgentTraceEvaluation(
                                    **scores, feedback=eval_dict["feedback"]
                                ),
                            )
                        )
        boxed_answer_success = self._check_boxed_answers(
            problem.get("solution", ""), current_trace
        )
        result = ProblemResult(
            id=problem.get("id", ""),
            type=problem.get("type", ""),
            problem=problem_text,
            solution=problem.get("solution", ""),
            final_trace=current_trace,
            # None (not False) when no evaluation was performed at all.
            agent_evaluate_success=self._check_score_threshold(scores)
            if scores
            else None,
            boxed_answer_success=boxed_answer_success,
            improvement_history=improvement_history,
        )
        # Write result to file immediately if output path is specified.
        # The lock serializes the read-modify-write cycle across threads.
        if self.output_path:
            with self.lock:
                try:
                    # Read existing results
                    # NOTE(review): opened without an explicit encoding —
                    # confirm this matches the encoding used when writing.
                    with open(self.output_path, 'r') as f:
                        data = json.load(f)
                    cleaned_result = self.clean_json(result.model_dump())
                    data['traces'].append(cleaned_result)
                    self.safe_write_json(self.output_path, data)
                except Exception as e:
                    logger.error(f"Error writing result to file: {e}")
        return result
def generate(self, rationalization: bool = False) -> List[Dict[str, Any]]:
r"""Execute the self-improving cot pipeline on all problems.
Process problems and return results. If output_path is specified,
also save results to file.
Args:
rationalization (bool, optional): Whether to use rationalization.
(default: :obj:`False`)
Returns:
List[Dict[str, Any]]: List of processed results
"""
# Pre-allocate results list
self.reasoning_traces = []
# Process problems in batches
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
results = loop.run_until_complete(
self._batch_process_problems(self.problems, rationalization)
)
finally:
loop.close()
self.reasoning_traces = [result.model_dump() for result in results]
return self.reasoning_traces
# Templates for generating reasoning, evaluation and improving them.
REASONING_TEMPLATE = """Let's solve this step by step:
Problem: {problem}
1. First, let's understand what we're asked
2. Let's break this down into parts
3. Let's solve each part systematically
4. Finally, let's verify our solution
{few_shot_examples}
Please show your complete reasoning process."""
EVALUATION_TEMPLATE = """Please evaluate this reasoning trace and
provide scores and feedback in valid JSON format.
Problem: {problem}
{solution}
Reasoning Trace:
{trace}
Evaluate for:
1. Correctness (Is each step logically sound?)
2. Clarity (Is the explanation clear and well-structured?)
3. Completeness (Are all necessary steps included?)
Respond ONLY with a JSON object in this exact format:
{{
"correctness": <score between 0 and 1>,
"clarity": <score between 0 and 1>,
"completeness": <score between 0 and 1>,
"feedback": "<specific feedback for improvement>"
}}"""
IMPROVEMENT_TEMPLATE = """Based on this feedback, generate an
improved reasoning trace:
Problem: {problem}
{solution}
Previous Trace:
{trace}
Feedback:
{feedback}
Generate a new, improved reasoning trace that addresses the feedback."""

View File

@@ -0,0 +1,36 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .filter import (
FILTER_REGISTRY,
FilterFunction,
InstructionFilter,
KeywordFilter,
LengthFilter,
NonEnglishFilter,
PunctuationFilter,
RougeSimilarityFilter,
)
from .self_instruct import SelfInstructPipeline
__all__ = [
'SelfInstructPipeline',
'InstructionFilter',
'NonEnglishFilter',
'PunctuationFilter',
'RougeSimilarityFilter',
'FilterFunction',
'KeywordFilter',
'LengthFilter',
'FILTER_REGISTRY',
]

View File

@@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .filter_function import (
FilterFunction,
KeywordFilter,
LengthFilter,
NonEnglishFilter,
PunctuationFilter,
RougeSimilarityFilter,
)
from .filter_registry import FILTER_REGISTRY
from .instruction_filter import InstructionFilter
__all__ = [
"LengthFilter",
"NonEnglishFilter",
"PunctuationFilter",
"RougeSimilarityFilter",
"FilterFunction",
"KeywordFilter",
"InstructionFilter",
"FILTER_REGISTRY",
]

View File

@@ -0,0 +1,216 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from abc import ABC, abstractmethod
from typing import List
from rouge import Rouge
from camel.models.reward import BaseRewardModel
class FilterFunction(ABC):
    r"""A base abstract class for filter functions.

    Subclasses must implement the `apply` method, which determines whether
    a given instruction passes the filter criteria.
    """

    @abstractmethod
    def apply(self, instruction: str) -> bool:
        r"""Evaluate the given instruction based on the filter's criteria.

        Args:
            instruction (str): The instruction to evaluate.

        Returns:
            bool: True if the instruction passes the filter, False otherwise.
        """
        pass
class LengthFilter(FilterFunction):
    r"""Filters instructions based on their word count.

    Args:
        min_len (int): The minimum word count required for an instruction.
            (default::obj:`5`)
        max_len (int): The maximum word count allowed for an instruction.
            (default::obj:`200`)
    """

    def __init__(self, min_len: int = 5, max_len: int = 200):
        self.min_len = min_len
        self.max_len = max_len

    def apply(self, instruction: str) -> bool:
        r"""Check whether the instruction's word count lies in range.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the length of the instruction is within the range
                of [min_len, max_len]
        """
        # Words are whitespace-separated tokens.
        words = instruction.split()
        if len(words) < self.min_len:
            return False
        return len(words) <= self.max_len
class KeywordFilter(FilterFunction):
    r"""Filters instructions that contain specific undesirable keywords.

    Args:
        keywords (List[str]): A list of keywords to filter out.
    """

    def __init__(self, keywords: List[str]):
        # Lowercase once up front so apply() is case-insensitive.
        self.keywords = [word.lower() for word in keywords]

    def apply(self, instruction: str) -> bool:
        r"""Check that the instruction contains none of the keywords.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True Instruction must NOT contain any of the keywords.
        """
        haystack = instruction.lower()
        for keyword in self.keywords:
            if keyword in haystack:
                return False
        return True
class PunctuationFilter(FilterFunction):
    r"""Filters instructions that begin with a non-alphanumeric character."""

    def apply(self, instruction: str) -> bool:
        r"""Check that the instruction does not open with punctuation.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the instruction does not start with punctuation.
        """
        # Leading character that is neither word character nor whitespace.
        leading_punctuation = re.match(r'^[^\w\s]', instruction)
        return leading_punctuation is None
class NonEnglishFilter(FilterFunction):
    r"""Filters instructions that do not begin with English letters."""

    def apply(self, instruction: str) -> bool:
        r"""Check that the instruction opens with an English letter.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the instruction starts with an English letter.
        """
        return re.match(r'^[A-Za-z]', instruction) is not None
class RougeSimilarityFilter(FilterFunction):
    r"""Filters instructions that are too similar to existing instructions
    based on ROUGE scores.

    Args:
        existing_instructions (List[str]): A list of existing instructions to
            compare against.
        threshold (float): The similarity threshold for filtering.
            (default::obj:`0.7`)
    """

    def __init__(
        self, existing_instructions: List[str], threshold: float = 0.7
    ):
        self.existing_instructions = existing_instructions
        self.threshold = threshold
        self.rouge = Rouge()

    def apply(self, instruction: str) -> bool:
        r"""Check the instruction's ROUGE-L similarity against the pool.

        Args:
            instruction (str): the instruction to be filtered.

        Returns:
            bool: True if the instruction's similarity to every existing
                instruction stays at or below the threshold.
        """
        # An empty pool trivially passes (nothing to be similar to).
        for existing_instr in self.existing_instructions:
            # ROUGE-L F-score of the candidate against one existing entry.
            similarity = self.rouge.get_scores(instruction, existing_instr)
            if similarity[0]['rouge-l']['f'] > self.threshold:
                return False
        return True
class RewardModelFilter(FilterFunction):
    r"""Filters instructions based on scores provided by a reward model.

    Args:
        reward_model (BaseRewardModel): The reward model used to evaluate
            the instructions.
        threshold (float): The minimum score required for an instruction
            to pass the filter. (default::obj:`0.5`)
    """

    def __init__(
        self,
        reward_model: BaseRewardModel,
        threshold: float = 0.5,
    ):
        # NOTE(review): presumably the pipeline sets `prompt` before
        # apply() is called — confirm against callers.
        self.prompt = ""
        self.reward_model = reward_model
        self.threshold = threshold

    def apply(self, instruction: str) -> bool:
        r"""Score the instruction with the reward model and compare it
        against the threshold.

        Args:
            instruction (str): The instruction to be filtered.

        Returns:
            bool: True if the instruction's score is above the threshold.

        Raises:
            ValueError: If no score types are available from the reward
                model, or if the required score is not found in the
                evaluation scores.
        """
        conversation = [
            {"role": "user", "content": self.prompt},
            {"role": "assistant", "content": instruction},
        ]
        scores = self.reward_model.evaluate(conversation)
        available_types = self.reward_model.get_scores_types()
        if not available_types:
            raise ValueError("No score types available from the reward model.")
        # Only the first advertised score type is consulted.
        primary_type = available_types[0]
        value = scores.get(primary_type, None)
        if value is None:
            raise ValueError(
                f"Score type '{primary_type}' is not found in the "
                "evaluation scores."
            )
        return value >= self.threshold

View File

@@ -0,0 +1,56 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Callable, Dict
from .filter_function import (
FilterFunction,
KeywordFilter,
LengthFilter,
NonEnglishFilter,
PunctuationFilter,
RewardModelFilter,
RougeSimilarityFilter,
)
# Maps a filter name to a constructor taking a kwargs dict, so that filters
# can be instantiated from plain configuration dictionaries. Extend via
# register_filter() below.
FILTER_REGISTRY: Dict[str, Callable[[Dict[str, Any]], FilterFunction]] = {
    "length": lambda kwargs: LengthFilter(
        min_len=kwargs.get("min_len", 5), max_len=kwargs.get("max_len", 200)
    ),
    "keyword": lambda kwargs: KeywordFilter(
        keywords=kwargs.get("keywords", ["image", "data"])
    ),
    "punctuation": lambda kwargs: PunctuationFilter(),
    "non_english": lambda kwargs: NonEnglishFilter(),
    "rouge_similarity": lambda kwargs: RougeSimilarityFilter(
        existing_instructions=kwargs.get("existing_instructions", []),
        threshold=kwargs.get("threshold", 0.7),
    ),
    # NOTE(review): the default threshold here (0.7) differs from
    # RewardModelFilter's own default (0.5) — confirm which is intended.
    "reward": lambda kwargs: RewardModelFilter(
        reward_model=kwargs.get("reward_model"),  # type:ignore[arg-type]
        threshold=kwargs.get("threshold", 0.7),
    ),
}
def register_filter(
    name: str, constructor: Callable[[Dict[str, Any]], FilterFunction]
):
    r"""Register ``constructor`` in FILTER_REGISTRY under the key ``name``.

    Args:
        name (str): Unique name of the filter.
        constructor (Callable[[Dict[str, Any]], FilterFunction]): Factory
            that builds the filter from a dictionary of parameters.
    """
    # Plain mapping insert: registering an existing name overwrites it.
    FILTER_REGISTRY.update({name: constructor})

View File

@@ -0,0 +1,97 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Tuple, Union
from camel.logger import get_logger
from .filter_function import FilterFunction, RewardModelFilter
from .filter_registry import FILTER_REGISTRY
# Module-level logger, named after this module via CAMEL's get_logger helper.
logger = get_logger(__name__)
class InstructionFilter:
    def __init__(
        self,
        filters_config: Dict[str, Dict[str, Any]],
        stop_on_first_failure: bool = False,
    ):
        r"""Build an InstructionFilter from a mapping of filter configs.

        Args:
            filters_config (Dict[str, Dict[str, Any]]): Maps a filter name
                (as registered in FILTER_REGISTRY) to the keyword parameters
                used to construct it. Example::

                    {
                        "length": {"min_len": 5, "max_len": 100},
                        "keyword": {"keywords": ["image", "video"]},
                        "non_english": {},
                        "rouge_similarity": {
                            "existing_instructions": ["Some existing text"],
                            "threshold": 0.6
                        }
                    }

            stop_on_first_failure (bool): If True, evaluation stops at the
                first failing filter.
        """
        self.filters: List[FilterFunction] = []
        for name, params in filters_config.items():
            # Reject unknown names up front so misconfiguration fails fast.
            if name not in FILTER_REGISTRY:
                raise ValueError(f"Unknown filter function: {name}")
            self.filters.append(FILTER_REGISTRY[name](params))
        self.stop_on_first_failure: bool = stop_on_first_failure

    def add_filter(self, filter_function: FilterFunction):
        r"""Append a custom filter that is not part of the registry.

        Args:
            filter_function (FilterFunction): The filter function to be added
        """
        self.filters.append(filter_function)

    def filter(
        self, prompt: str, instruction: str, return_details: bool = False
    ) -> Union[bool, Tuple[bool, List[str]]]:
        r"""Run every configured filter against ``instruction``.

        Args:
            prompt (str): The prompt that produced the instruction; it is
                forwarded to reward-model based filters.
            instruction (str): The instruction to evaluate.
            return_details (bool): If True, also return the class names of
                the filters that rejected the instruction.
                (default::obj:`False`)

        Returns:
            bool: True if the instruction passes all filters, False otherwise.
            OR (bool, List[str]) if return_details is True.
        """
        failed: List[str] = []
        for candidate in self.filters:
            # Reward-model filters score the instruction in context, so they
            # need the originating prompt attached before apply().
            if isinstance(candidate, RewardModelFilter):
                candidate.prompt = prompt
            if candidate.apply(instruction):
                continue
            filter_name = type(candidate).__name__
            failed.append(filter_name)
            logger.warning(
                f"{filter_name} failed instruction: {instruction}"
            )
            if self.stop_on_first_failure:
                break
        passed = len(failed) == 0
        return (passed, failed) if return_details else passed

View File

@@ -0,0 +1,445 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import random
import time
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from camel.agents import ChatAgent
from camel.logger import get_logger
from .filter import RougeSimilarityFilter
from .filter.instruction_filter import InstructionFilter
from .templates import SelfInstructTemplates
# Module-level logger, named after this module via CAMEL's get_logger helper.
logger = get_logger(__name__)
class SelfInstructPipeline:
    r"""A pipeline to generate and manage machine-generated instructions for
    tasks, combining human and machine task samples.

    Args:
        agent (ChatAgent): The agent used to interact and generate
            instructions.
        seed (str): The path to the human-written instructions.
        num_machine_instructions (int): Number of machine-generated
            instructions to generate. (default::obj:`5`)
        data_output_path (Optional[str]): Path to save the generated data.
            (default::obj:`./data_output.json`)
        human_to_machine_ratio (tuple): Ratio of human to machine tasks used
            for instruction generation. (default::obj:`(6, 2)`)
        instruction_filter (InstructionFilter): A filter to validate
            generated instructions. (default::obj:`None`)
        filter_config (Optional[Dict[str, Dict[str, Any]]]): configuration
            for the filter functions registered in FILTER_REGISTRY.
            (default::obj:`None`)
        stop_on_first_failure (bool): If True, stops checking filters after
            the first failure.
    """

    def __init__(
        self,
        agent: ChatAgent,
        seed: str,
        num_machine_instructions: int = 5,
        data_output_path: Optional[str] = './data_output.json',
        human_to_machine_ratio: tuple = (6, 2),
        instruction_filter: Optional[InstructionFilter] = None,
        filter_config: Optional[Dict[str, Dict[str, Any]]] = None,
        stop_on_first_failure: bool = False,
    ):
        self.agent = agent
        self.num_machine_instructions = num_machine_instructions
        self.data_output_path = data_output_path
        self.human_to_machine_ratio = human_to_machine_ratio
        self.human_tasks: List[Dict] = []  # seed tasks loaded from disk
        self.machine_tasks: List[Dict] = []  # accepted generated tasks
        self.load_seed(seed)
        # Empty per-filter dicts mean "use each filter's registry defaults".
        default_config: Dict[str, Dict[str, Any]] = {
            "length": {},
            "keyword": {},
            "punctuation": {},
            "non_english": {},
            "rouge_similarity": {},
        }
        if instruction_filter is not None:
            # custom
            self.instruction_filter = instruction_filter
        else:
            # default
            config_to_use = (
                filter_config if filter_config is not None else default_config
            )
            self.instruction_filter = InstructionFilter(
                config_to_use, stop_on_first_failure
            )

    def load_seed(self, path: str):
        r"""Load seed tasks from a file. Defaults to a predefined seed file if
        no path is provided.

        Args:
            path (str): Path to the seed file.

        Raises:
            FileNotFoundError: If the seed file does not exist.
        """
        if os.path.exists(path):
            with open(path, 'r') as f:
                # Seed file is JSON Lines: one task object per non-empty line.
                for line in f:
                    line = line.strip()
                    if line:
                        self.human_tasks.append(json.loads(line))
        else:
            raise FileNotFoundError(f"Seed file not found at path: {path}")

    def sample_human_tasks(self, count: int) -> List[dict]:
        r"""Sample a specified number of human tasks from the loaded seed.

        Args:
            count (int): Number of human tasks to sample.

        Returns:
            List[dict]: A list of sampled human tasks.
        """
        # min() guards against requesting more tasks than are available,
        # since random.sample raises ValueError on over-sampling.
        return random.sample(
            self.human_tasks, min(count, len(self.human_tasks))
        )

    def sample_machine_tasks(self, count: int) -> List[dict]:
        r"""Sample a specified number of machine tasks.

        Args:
            count (int): Number of machine tasks to sample.

        Returns:
            List[dict]: A list of sampled machine tasks, with placeholders if
                insufficient tasks are available.
        """
        available_machine_tasks = len(self.machine_tasks)
        if available_machine_tasks < count:
            # Early in the run there may be too few machine tasks; pad with
            # empty-instruction placeholders so prompts keep a fixed shape.
            sampled_tasks = self.machine_tasks.copy()
            placeholders_needed = count - available_machine_tasks
            sampled_tasks.extend(
                [{'instruction': ""} for _ in range(placeholders_needed)]
            )
            return sampled_tasks
        return random.sample(self.machine_tasks, count)

    def generate_machine_instruction(self) -> List:
        r"""Generate a machine instruction using the agent.

        Combines human and machine tasks based on the configured ratio to
        create a prompt for instruction generation.

        Returns:
            List: The prompt and a machine-generated instruction.
        """
        sampled_human_tasks = self.sample_human_tasks(
            self.human_to_machine_ratio[0]
        )
        sampled_machine_tasks = self.sample_machine_tasks(
            self.human_to_machine_ratio[1]
        )
        # Build a few-shot prompt: human examples first, then machine ones,
        # with a continuous task numbering across both groups.
        prompt = "Below are some tasks:\n\n"
        for idx, task in enumerate(sampled_human_tasks, 1):
            prompt += f"Task {idx}: {task['instruction']}\n"
        current_task_number = len(sampled_human_tasks) + 1
        for idx, task in enumerate(sampled_machine_tasks, current_task_number):
            prompt += f"Task {idx}: {task['instruction']}\n"
        task_num = len(sampled_human_tasks) + len(sampled_machine_tasks) + 1
        prompt += f"Task {task_num}:"
        prompt += (
            "\nNow, please produce exactly one new task that fits the "
            "style of the ones above.\n Do not include any task numbering or "
            "labels like 'Task X:'. Just write the task itself.\n"
            "The task should be a single sentence.\n\n"
        )
        response = self.agent.step(prompt)
        # Reset so each generation call starts from a clean agent state.
        self.agent.reset()
        generated_tasks = [
            line.strip()
            for line in response.msgs[0].content.split("\n")
            if line.strip()
        ]
        # Only the first non-empty line is kept as the new instruction.
        return [prompt, generated_tasks[0]]

    def identify_instruction(self, instruction: str) -> bool:
        r"""Determine if the given instruction is a classification task.

        Args:
            instruction (str): The instruction to classify.

        Returns:
            bool: True if the instruction is a classification task,
                otherwise False.
        """
        clf_prompt = (
            SelfInstructTemplates.clf_template
            + f"Task: {instruction}\nIs it classification?"
            + "\nRespond in the following structured format:"
            "\n{\n \"answer\": true\n}\n"
            "or\n"
            "{\n \"answer\": false\n}\n"
        )
        response = self.agent.step(clf_prompt)
        self.agent.reset()
        try:
            # NOTE(review): ``parse_raw`` is the pydantic v1 API; v2
            # deprecates it in favor of ``model_validate_json`` — confirm the
            # pinned pydantic version. Parse failures presumably surface as
            # pydantic's ValidationError, caught below via ValueError.
            structured_response = AgentResponse.parse_raw(
                response.msgs[0].content.strip()
            )
            return structured_response.answer
        except ValueError as e:
            # Treat unparseable answers as "not classification".
            logger.error(f"Error parsing agent response: {e}")
            return False

    def generate_machine_instances(self):
        r"""Generate instances for each machine task based on its
        classification status.
        """
        logger.info(
            f"Starting output generation: target {len(self.machine_tasks)} "
            f"instructions"
        )
        attempt_count = 0
        for instruction in self.machine_tasks:
            # Mutates each task dict in place by attaching its instances.
            instance = self.generate_machine_instance(
                instruction['instruction'], instruction['is_classification']
            )
            instruction['instances'] = instance
            attempt_count += 1
            logger.info(
                f"Attempt[Output]: Progress {attempt_count}/"
                f"{len(self.machine_tasks)} instructions"
            )

    def generate_machine_instance(
        self, instruction: str, classification: bool
    ) -> List[Dict[str, str]]:
        r"""Generate instances for a given instruction.

        Args:
            instruction (str): The instruction to create instances for.
            classification (bool): Whether the instruction is a classification
                task.

        Returns:
            List[dict]: A list of generated instances in input-output format.
        """
        # Classification tasks use the output-first template; generation
        # tasks use the input-first template.
        if classification:
            prompt = (
                SelfInstructTemplates.output_first_template_for_clf.format(
                    instruction=instruction
                )
            )
        else:
            prompt = SelfInstructTemplates.input_first_template_for_gen.format(
                instruction=instruction
            )
        response = self.agent.step(prompt)
        self.agent.reset()
        generated_text = response.msgs[0].content.strip()
        if classification:
            return self.parse_classification_output(generated_text)
        else:
            return self.parse_non_classification_output(generated_text)

    def parse_classification_output(
        self, generated_text: str
    ) -> List[Dict[str, str]]:
        r"""Parse the generated text for classification tasks into input-output
        pairs.

        Args:
            generated_text (str): The raw text generated by the agent for
                classification tasks.

        Returns:
            List[Dict[str, str]]: A list of dictionaries with 'input' and
                'output' keys.
        """
        instances = []
        lines = generated_text.split("\n")
        # Simple line-oriented state machine: a "Class label:" line closes
        # the previous (label, input) pair and starts a new one; any other
        # line accumulates into the current input.
        current_label = None
        current_input = None
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line.startswith("Class label:"):
                if current_label and current_input:
                    instances.append(
                        {
                            "input": current_input.strip(),
                            "output": current_label.strip(),
                        }
                    )
                current_label = line[len("Class label:") :].strip()
                current_input = None
            else:
                if current_input is None:
                    current_input = line
                else:
                    current_input += f"\n{line}"
        # Flush the trailing pair, if any.
        if current_label and current_input:
            instances.append(
                {
                    "input": current_input.strip(),
                    "output": current_label.strip(),
                }
            )
        return instances

    def parse_non_classification_output(
        self, generated_text: str
    ) -> List[Dict[str, str]]:
        r"""Parse the generated text for non-classification tasks into
        input-output pairs.

        Args:
            generated_text (str): The raw text generated by the agent for
                non-classification tasks.

        Returns:
            List[Dict[str, str]]: A list of dictionaries with 'input' and
                'output' keys.
        """
        instances = []
        # ``prev`` marks the first line of the current example's input block.
        prev = 0
        lines = generated_text.split("\n")
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            if line.startswith("Example "):
                prev = i + 1
            elif line.startswith("Output:"):
                # Everything between ``prev`` and this line is the input.
                instance_input = '\n'.join(lines[prev:i]).strip()
                if instance_input.startswith("Input: "):
                    instance_input = instance_input[len("Input: ") :].strip()
                else:
                    instance_input = instance_input.strip()
                instance_output = line[len("Output:") :].strip()
                # The output may span multiple lines, up to the next
                # "Example " marker (or end of text).
                i += 1
                while i < len(lines) and not lines[i].strip().startswith(
                    "Example "
                ):
                    instance_output += '\n' + lines[i].strip()
                    i += 1
                # Step back so the outer loop re-reads the "Example " line.
                i -= 1
                instance_output = instance_output.strip()
                instances.append(
                    {"input": instance_input, "output": instance_output}
                )
                prev = i + 1
            i += 1
        # Guarantee at least one instance so downstream code has a record.
        if not instances:
            instances.append({"input": "", "output": "No valid output found."})
        return instances

    def construct_data(self):
        r"""Save the machine-generated tasks to the specified output path
        in JSON format.
        """
        # NOTE(review): ``data_output_path`` is typed Optional — confirm
        # callers never pass None, since open(None) would raise here.
        with open(self.data_output_path, 'w') as f:
            json.dump(self.machine_tasks, f, indent=4, ensure_ascii=False)

    def generate(self, timeout_minutes: int = 600):
        r"""Execute the entire pipeline to generate machine instructions
        and instances.

        Args:
            timeout_minutes (int): Maximum time in minutes to run the
                generation process before timing out. (default: :obj:`600`)
        """
        start_time = time.time()
        timeout_seconds = timeout_minutes * 60
        logger.info(
            f"Starting instruction generation: target "
            f"{self.num_machine_instructions} instructions"
        )
        while len(self.machine_tasks) < self.num_machine_instructions:
            # Check for timeout
            elapsed = time.time() - start_time
            if elapsed > timeout_seconds:
                logger.info(
                    f"Generation timed out after {elapsed / 60:.1f} minutes. "
                    f"Generated {len(self.machine_tasks)}/"
                    f"{self.num_machine_instructions} instructions."
                )
                break
            prompt, instruction = self.generate_machine_instruction()
            # Keep the ROUGE-similarity filter's reference set in sync with
            # everything generated or loaded so far, to avoid duplicates.
            existing_instructions = [
                t["instruction"] for t in self.human_tasks
            ] + [t["instruction"] for t in self.machine_tasks]
            for f in self.instruction_filter.filters:
                if isinstance(f, RougeSimilarityFilter):
                    f.existing_instructions = existing_instructions
            if self.instruction_filter.filter(prompt, instruction):
                instruction_dict = {
                    "id": f"machine_task_{len(self.machine_tasks) + 1}",
                    "instruction": instruction,
                    "is_classification": self.identify_instruction(
                        instruction
                    ),
                }
                self.machine_tasks.append(instruction_dict)
                logger.info(
                    f"Attempt[Instruction]: Progress "
                    f"{len(self.machine_tasks)}/"
                    f"{self.num_machine_instructions} "
                    f"instructions"
                )
            else:
                logger.warning(
                    f"Instruction failed filters. Skipping instruction: "
                    f"{instruction}"
                )
        # After instruction generation, produce instances and persist.
        self.generate_machine_instances()
        self.construct_data()
class AgentResponse(BaseModel):
    r"""Structured JSON answer parsed by
    ``SelfInstructPipeline.identify_instruction``."""

    # True when the agent judged the instruction to be a classification task.
    answer: bool = Field(
        ...,
        description="Indicates whether the task is "
        "classification (True/False).",
    )

View File

@@ -0,0 +1,382 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from dataclasses import dataclass
# flake8: noqa
@dataclass(frozen=True)
class SelfInstructTemplates:
    r"""Prompt templates for self-instruct data generation."""

    # Few-shot prompt for deciding whether an instruction is a classification
    # task. The caller appends "Task: ...\nIs it classification?" plus the
    # required JSON answer format after these examples.
    clf_template = """ '''Can the following task be regarded as a classification task with finite output labels?
Task: Given my personality and the job, tell me if I would be suitable.
Is it classification? Yes
Task: Give me an example of a time when you had to use your sense of humor.
Is it classification? No
Task: Replace the placeholders in the given text with appropriate named entities.
Is it classification? No
Task: Fact checking - tell me if the statement is true, false, or unknown, based on your knowledge and common sense.
Is it classification? Yes
Task: Return the SSN number for the person.
Is it classification? No
Task: Detect if the Reddit thread contains hate speech.
Is it classification? Yes
Task: Analyze the sentences below to identify biases.
Is it classification? No
Task: Select the longest sentence in terms of the number of words in the paragraph, output the sentence index.
Is it classification? Yes
Task: Find out the toxic word or phrase in the sentence.
Is it classification? No
Task: Rank these countries by their population.
Is it classification? No
Task: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include: Music, Sports, Politics, Tech, Finance, Basketball, Soccer, Tennis, Entertainment, Digital Game, World News. Output its categories one by one, seperated by comma.
Is it classification? Yes
Task: Given the name of an exercise, explain how to do it.
Is it classification? No
Task: Select the oldest person from the list.
Is it classification? Yes
Task: Find the four smallest perfect numbers.
Is it classification? No
Task: Does the information in the document supports the claim? You can answer "Support" or "Unsupport".
Is it classification? Yes
Task: Create a detailed budget for the given hypothetical trip.
Is it classification? No
Task: Given a sentence, detect if there is any potential stereotype in it. If so, you should explain the stereotype. Else, output no.
Is it classification? No
Task: Explain the following idiom to me, and try to give me some examples.
Is it classification? No
Task: Is there anything I can eat for a breakfast that doesn't include eggs, yet includes protein, and has roughly 700-1000 calories?
Is it classification? No
Task: Answer the following multiple choice question. Select A, B, C, or D for the final answer.
Is it classification? Yes
Task: Decide whether the syllogism is logically sound.
Is it classification? Yes
Task: How can individuals and organizations reduce unconscious bias?
Is it classification? No
Task: What are some things you can do to de-stress?
Is it classification? No
Task: Find out the largest one from a set of numbers. Output the number directly.
Is it classification? Yes
Task: Replace the <mask> token in the text with proper words that are consistent with the context. You can use multiple words for each <mask> token.
Is it classification? No
Task: Write a cover letter based on the given facts.
Is it classification? No
Task: Identify the pos tag of the word in the given sentence.
Is it classification? Yes
Task: Write a program to compute the sum of integers from k to n.
Is it classification? No
Task: In this task, you need to compare the meaning of the two sentences and tell if they are the same. Output yes or no.
Is it classification? Yes
Task: To make the pairs have the same analogy, write the fourth word.
Is it classification? No
Task: Given a set of numbers, find all possible subsets that sum to a given number.
Is it classification? No
"""
    # Output-first instance-generation template for classification tasks:
    # the model emits a "Class label:" line first, then the example input.
    # Formatted with .format(instruction=...) via the trailing {instruction}.
    output_first_template_for_clf = '''You are given a classification instruction.
Produce multiple labeled examples following the format below. For each example:
- Begin with a "Class label:" line identifying one possible category.
- Follow that with one line specifying the example input (e.g., "Sentence:", "Dialogue:", "Opinion:", or "Email:").
- The content after these lines should serve as an illustrative example of that label.
Do not restate or include the "Task:" line. Do not add additional commentary. Just produce the labeled examples.
Example format (no initial task line, task will be provided) when task is Task: Classify the sentiment of the sentence into positive, negative, or mixed.:
Class label: mixed
Sentence: I enjoy the flavor of the restaurant but their service is too slow.
Class label: Positive
Sentence: I had a great day today. The weather was beautiful and I spent time with friends and family.
Class label: Negative
Sentence: I was really disappointed by the latest superhero movie. I would not recommend it to anyone.
Below are more examples:
Task: Given a dialogue, classify whether the user is satisfied with the service. You should respond with "Satisfied" or "Unsatisfied".
Class label: Satisfied
Dialogue:
- Agent: Thank you for your feedback. We will work to improve our service in the future.
- Customer: I am happy with the service you provided. Thank you for your help.
Class label: Unsatisfied
Dialogue:
- Agent: I am sorry we will cancel that order for you, and you will get a refund within 7 business days.
- Customer: oh that takes too long. I want you to take quicker action on this.
Task: Given some political opinions, classify whether the person belongs to Democrats or Republicans.
Class label: Democrats
Opinion: I believe that everyone should have access to quality healthcare regardless of their income level.
Class label: Republicans
Opinion: I believe that people should be able to keep more of their hard-earned money and should not be taxed at high rates.
Task: Tell me if the following email is a promotion email or not.
Class label: Promotion
Email: Check out our amazing new sale! We've got discounts on all of your favorite products.
Class label: Not Promotion
Email: We hope you are doing well. Let us know if you need any help.
Task: Detect if the Reddit thread contains hate speech.
Class label: Hate Speech
Thread: All people of color are stupid and should not be allowed to vote.
Class label: Not Hate Speech
Thread: The best way to cook a steak on the grill.
Task: Does the information in the document supports the claim? You can answer "Support" or "Unsupport".
Class label: Unsupport
Document: After a record-breaking run that saw mortgage rates plunge to all-time lows and home prices soar to new highs, the U.S. housing market finally is slowing. While demand and price gains are cooling, any correction is likely to be a modest one, housing economists and analysts say. No one expects price drops on the scale of the declines experienced during the Great Recession.
Claim: The US housing market is going to crash soon.
Class label: Support
Document: The U.S. housing market is showing signs of strain, with home sales and prices slowing in many areas. Mortgage rates have risen sharply in recent months, and the number of homes for sale is increasing. This could be the beginning of a larger downturn, with some economists predicting a potential housing crash in the near future.
Claim: The US housing market is going to crash soon.
Task: Answer the following multiple-choice question. Select A, B, C, or D for the final answer.
Class label: C
Question: What is the capital of Germany?
A. London
B. Paris
C. Berlin
D. Rome
Class label: D
Question: What is the largest planet in our solar system?
A) Earth
B) Saturn
C) Mars
D) Jupiter
Class label: A
Question: What is the process by which plants make their own food through photosynthesis?
A) Respiration
B) Fermentation
C) Digestion
D) Metabolism
Class label: B
Question: Who wrote the novel "The Great Gatsby"?
A) Ernest Hemingway
B) F. Scott Fitzgerald
C) J.D. Salinger
D) Mark Twain
Task: You need to read a code and detect if there is a syntax error or not. Output true if there is an error, output false if there is not.
Class label: true
Code:
def quick_sort(arr):
    if len(arr) < 2
        return arr
Class label: False
Code:
def calculate_average(numbers):
    total = 0
    for number in numbers:
        total += number
    return total / len(numbers)
Task: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include Sports and Politics. Output its categories one by one, separated by a comma.
Class label: Sports
Article: The Golden State Warriors have won the NBA championship for the second year in a row.
Class label: Politics
Article: The United States has withdrawn from the Paris Climate Agreement.
Class label: Politics, Sports
Article: The government has proposed cutting funding for youth sports programs.
Task: Given a credit card statement, the cardholder's spending habits, and the account balance, classify whether the cardholder is at risk of defaulting on their payments or not.
Class label: At risk
Credit card statement: Purchases at high-end clothing stores and luxury hotels.
Cardholder's spending habits: Frequent purchases at luxury brands and high-end establishments.
Account balance: Over the credit limit and multiple missed payments.
Class label: Not at risk
Credit card statement: Purchases at grocery stores and gas stations.
Cardholder's spending habits: Regular purchases for necessary expenses and occasional dining out.
Account balance: Slightly below the credit limit and no missed payments.
Task: Given a social media post, the hashtags used, and a topic. classify whether the post is relevant to the topic or not.
Class label: Relevant
Post: I can't believe the government is still not taking action on climate change. It's time for us to take matters into our own hands.
Hashtags: #climatechange #actnow
Topic: Climate change
Class label: Not relevant
Post: I just bought the new iPhone and it is amazing!
Hashtags: #apple #technology
Topic: Travel
Task: The answer will be 'yes' if the provided sentence contains an explicit mention that answers the given question. Otherwise, answer 'no'.
Class label: Yes
Sentence: Jack played basketball for an hour after school.
Question: How long did Jack play basketball?
Class label: No
Sentence: The leaders of the Department of Homeland Security now appear before 88 committees and subcommittees of Congress.
Question: How often are they required to appear?
Task: Tell me what's the second largest city by population in Canada.
Class label: Montreal
Task: Classifying different types of mathematical equations, such as linear, and quadratic equations, based on the coefficients and terms in the equation.
Class label: Linear equation
Equation: y = 2x + 5
Class label: Quadratic equation
Equation: y = x^2 - 4x + 3
Task: Tell me the first number of the given list.
Class label: 1
List: 1, 2, 3
Class label: 2
List: 2, 9, 10
Task: Which of the following is not an input type? (a) number (b) date (c) phone number (d) email address (e) all of these are valid inputs.
Class label: (e)
Now, using the given instruction, produce several formatted examples accordingly:
Task: {instruction}
'''
    # Input-first instance-generation template for non-classification
    # (generation) tasks: at most two "Example N" blocks, each with an
    # optional input and an "Output:" line. Formatted with
    # .format(instruction=...) via the trailing {instruction}.
    input_first_template_for_gen = '''You will be given a task,
Your job is to generate at most two example instances demonstrating how to
perform this task. For each instance:
- If the task requires input (as an actual example of the task), provide it.
- If the task can be answered directly without requiring input, omit the input section.
Example 1
Input: [Provide input here if needed, otherwise omit this section]
Output: [Provide the correct output]
Example 2
Input: [Provide input here if needed, otherwise omit this section]
Output: [Provide the correct output]
Do not include any additional commentary, explanations, or more than two instances.
Below are some examples:
Task: Which exercises are best for reducing belly fat at home?
Output:
- Lying Leg Raises
- Leg In And Out
- Plank
- Side Plank
- Sit-ups
Task: Extract all the country names in the paragraph, list them separated by commas.
Example 1
Paragraph: Dr. No is the sixth novel by the English author Ian Fleming to feature his British Secret Service agent James Bond. Written at Fleming's Goldeneye estate in Jamaica, it was first published in the United Kingdom by Jonathan Cape in 1958. In the novel Bond looks into the disappearance in Jamaica of two fellow MI6 operatives who had been investigating Doctor No. Bond travels to No's Caribbean island and meets Honeychile Rider, who is there to collect shells. They are captured and taken to a luxurious facility carved into a mountain. The character of Doctor No, the son of a German missionary and a Chinese woman, was influenced by Sax Rohmer's Fu Manchu stories. Dr. No was the first of Fleming's novels to face widespread negative reviews in Britain, but it was received more favourably in the United States.
Output: English, British, Jamaica, the United Kingdom, German, Chinese, Britain, the United States.
Task: Converting 85 F to Celsius.
Output: 85°F = 29.44°C
Task: Sort the given list ascendingly.
Example 1
List: [10, 92, 2, 5, -4, 92, 5, 101]
Output: [-4, 2, 5, 5, 10, 92, 92, 101]
Example 2
Input 2 - List: [9.99, 10, -5, -1000, 5e6, 999]
Output: [-1000, -5, 9.99, 10, 999, 5e6]
Task: Suggest a better and more professional rephrasing of the following sentence.
Example 1
Sentence: This house is surprisingly not constructed very well, and you probably need more money to fix it after you buy it. If you ask me, I would suggest you to consider other candidates.
Output: This house does not seem to be constructed well, so you may need to spend more money to fix it after you purchase it. I would suggest that you look at other properties.
Example 2
Sentence: Just so you know, we did an experiment last week and found really surprising results - language model can improve itself!
Output: Our experiments last week demonstrated surprising results, proving that the language model can improve itself.
Task: Read the following paragraph and answer a math question about the paragraph. You need to write out the calculation for getting the final answer.
Example 1
Paragraph: Gun violence in the United States results in tens of thousands of deaths and injuries annually, and was the leading cause of death for children 19 and younger in 2020. In 2018, the most recent year for which data are available as of 2021, the Centers for Disease Control and Prevention's (CDC) National Center for Health Statistics reports 38,390 deaths by firearm, of which 24,432 were by suicide. The rate of firearm deaths per 100,000 people rose from 10.3 per 100,000 in 1999 to 12 per 100,000 in 2017, with 109 people dying per day or about 14,542 homicides in total, being 11.9 per 100,000 in 2018. In 2010, there were 19,392 firearm-related suicides, and 11,078 firearm-related homicides in the U.S. In 2010, 358 murders were reported involving a rifle while 6,009 were reported involving a handgun; another 1,939 were reported with an unspecified type of firearm. In 2011, a total of 478,400 fatal and nonfatal violent crimes were committed with a firearm.
Question: How many more firearm-related deaths were there in 2018 compared to 2010?
Output:
38390 - (19392 + 11078) = 38390 - 30470 = 7920.
So, in 2018, there were 7920 more deaths by firearm than in 2010.
Task: Write Python code to solve this leetcode problem.
Example 1
Problem: You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order, and each of their nodes contains a single digit. Add the two numbers and return the sum as a linked list. You may assume the two numbers do not contain any leading zero, except the number 0 itself.
Output:
class Solution(object):
    def addTwoNumbers(self, l1, l2):
        carry = 0
        root = n = ListNode(0)
        while l1 or l2 or carry:
            v1 = v2 = 0
            if l1:
                v1 = l1.val
                l1 = l1.next
            if l2:
                v2 = l2.val
                l2 = l2.next
            carry, val = divmod(v1+v2+carry, 10)
            n.next = ListNode(val)
            n = n.next
        return root.next
Task: Solve the equation and find the value of X. Show your steps.
Example 1
Equation: 10X + 5 = 10
Output: 10X = 5, X = 0.5
Example 2
Equation: X + Y + 120 = 100
Output: X + Y = -20, X = -20 - Y
Task: Write a program to compute the sum of integers from k to n.
Output:
def sum(k, n):
    sum = 0
    for i in range(k, n+1):
        sum += i
    return sum
Task: Select the oldest person from the given list.
Example 1
List: George Washington, Confucius, Michael Jordan, Michelangelo
Output: Confucious
Example 2
List: Alan Turing, Geoffrey Hinton, Yann LeCun, Yoshua Bengio
Output: Alan Turing
Task: Turn down a job offer by sending an email to a recruiter explaining the reason.
Output: Hi [Recruiter],
Thank you so much for the generous offer to join your team. As we discussed, Ive admired the company for a number of years, and am a proud endorser of its products. However, after further consideration of where I currently am in my career, Ive decided to accept an offer at another company.
I would love to stay in touch with you and have already started following you on [Social Media Platform]. Again, thank you so much for your time and consideration.
Thanks again,
[Your Name]
Task: {instruction}
'''

View File

@@ -0,0 +1,31 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .data_processor import (
DataCurator,
ExampleConstructor,
UserDataProcessor,
)
from .models import MultiHopQA, ReasoningStep
from .user_data_processor_config import (
ProcessorConfig,
)
# Public API of the source2synth data-generation package.
__all__ = [
    "DataCurator",
    "ExampleConstructor",
    "ProcessorConfig",
    "UserDataProcessor",
    "ReasoningStep",
    "MultiHopQA",
]

View File

@@ -0,0 +1,538 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
from typing import Any, Dict, List, Optional, Sequence
from tqdm import tqdm
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent
from camel.datagen.source2synth.user_data_processor_config import (
ProcessorConfig,
)
from camel.logger import get_logger
logger = get_logger(__name__)
class UserDataProcessor:
    r"""A processor for generating multi-hop question-answer pairs from user
    data.

    This class handles the processing of text data to generate multi-hop
    question-answer pairs using either an AI model or rule-based approaches.
    It manages the entire pipeline from text preprocessing to dataset
    curation.

    Attributes:
        config (ProcessorConfig): Configuration for data processing
            parameters.
        rng (random.Random): Random number generator for reproducibility.
        multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for
            generating QA pairs.
    """

    def __init__(self, config: Optional[ProcessorConfig] = None):
        r"""Initialize the UserDataProcessor.

        Args:
            config (Optional[ProcessorConfig], optional): Configuration for
                data processing. (default: :obj:`None`)
        """
        self.config = config or ProcessorConfig()
        # Seed the RNG from the config so dataset sampling is reproducible.
        self.rng = random.Random(self.config.seed)
        # The agent is only attached when AI-based generation is enabled;
        # without it, downstream QA generation yields no pairs.
        self.multi_hop_agent = (
            self.config.hop_generating_agent
            if self.config.use_ai_model
            else None
        )

    def process_text(
        self, text: str, source: str = "user_input"
    ) -> List[Dict[str, Any]]:
        r"""Process a single text to generate multi-hop QA pairs.

        Args:
            text (str): The input text to process.
            source (str, optional): Source identifier for the text.
                (default: :obj:`"user_input"`)

        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs
                and metadata.
        """
        # A single text is just a batch of size one; delegating removes the
        # duplicated construction/curation pipeline that previously lived
        # in both methods.
        return self.process_batch([text], [source])

    def process_batch(
        self, texts: List[str], sources: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        r"""Process multiple texts in batch to generate multi-hop QA pairs.

        Args:
            texts (List[str]): List of input texts to process.
            sources (Optional[List[str]], optional): List of source
                identifiers, one per text. (default: :obj:`None`)

        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs
                and metadata.

        Raises:
            ValueError: If length of sources doesn't match length of texts.
        """
        if sources is None:
            # Default every text to the generic "user_input" source.
            sources = ["user_input"] * len(texts)
        elif len(sources) != len(texts):
            raise ValueError("Length of sources must match length of texts")
        # Convert to the standard raw-data format used by the constructor.
        raw_data = [
            {
                'text': text,
                'source': source,
            }
            for text, source in zip(texts, sources)
        ]
        # Construct examples, then curate them into the final dataset.
        constructor = ExampleConstructor(self.config, self.multi_hop_agent)
        examples = constructor.construct_examples(raw_data)
        curator = DataCurator(self.config, self.rng)
        return curator.curate_dataset(examples)
class ExampleConstructor:
    r"""Constructs training examples from raw text data.

    This class handles the construction of training examples by
    preprocessing text, extracting information pairs, and generating
    question-answer pairs.

    Attributes:
        config (ProcessorConfig): Configuration for example construction.
        multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for QA
            generation.
    """

    def __init__(
        self,
        config: ProcessorConfig,
        multi_hop_agent: Optional[MultiHopGeneratorAgent] = None,
    ):
        r"""Initialize the ExampleConstructor.

        Args:
            config (ProcessorConfig): Configuration for example construction.
            multi_hop_agent (Optional[MultiHopGeneratorAgent], optional):
                Agent for generating multi-hop QA pairs.
                (default: :obj:`None`)
        """
        self.config = config
        self.multi_hop_agent = multi_hop_agent

    def construct_examples(
        self, raw_data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Construct training examples from raw data.

        Args:
            raw_data (List[Dict[str, Any]]): List of raw data dictionaries
                containing text and metadata.

        Returns:
            List[Dict[str, Any]]: List of constructed examples with QA pairs
                and metadata.
        """
        logger.info("Starting to construct training examples...")
        examples = []
        for data in tqdm(raw_data, desc="Constructing examples"):
            # 1. Text preprocessing; entries failing quality checks come
            # back as '' and are skipped.
            processed_text = self._preprocess_text(data.get('text', ''))
            if not processed_text:
                continue
            # 2. Generate key information pairs
            info_pairs = self._extract_info_pairs(processed_text)
            # 3. Construct question-answer pairs
            qa_pairs = self._generate_qa_pairs(info_pairs)
            # 4. Add metadata
            example = {
                'text': processed_text,
                'qa_pairs': qa_pairs,
                'metadata': {
                    'source': data.get('source', 'unknown'),
                    'timestamp': data.get('timestamp', ''),
                    'complexity': self._calculate_complexity(qa_pairs),
                },
            }
            examples.append(example)
        logger.info(f"Successfully constructed {len(examples)} examples")
        return examples

    def _preprocess_text(self, text: str) -> str:
        r"""Preprocess input text for example construction.

        Args:
            text (str): Input text to preprocess.

        Returns:
            str: Preprocessed text, or empty string if text fails quality
                checks.
        """
        if not isinstance(text, str):
            return ''
        # 1. Basic cleaning
        text = text.strip()
        # 2. Length check against configured bounds
        if (
            len(text) < self.config.min_length
            or len(text) > self.config.max_length
        ):
            return ''
        # 3. Quality check
        if not self._check_text_quality(text):
            return ''
        return text

    def _check_text_quality(self, text: str) -> bool:
        r"""Check the quality of input text.

        Args:
            text (str): Text to check quality for.

        Returns:
            bool: True if text passes quality checks, False otherwise.
        """
        # 1. Basic quality check: must contain at least 2 sentences
        # (approximated by counting periods).
        if text.count('.') < 2:
            return False
        # 2. Special character ratio check: no more than 30% of characters
        # may be non-alphanumeric, non-whitespace.
        special_char_ratio = len(
            [c for c in text if not c.isalnum() and not c.isspace()]
        ) / len(text)
        if special_char_ratio > 0.3:
            return False
        return True

    def _extract_info_pairs(self, text: str) -> List[Dict[str, Sequence[str]]]:
        r"""Extract information pairs and relationships from text.

        Args:
            text (str): Input text to extract information from.

        Returns:
            List[Dict[str, Sequence[str]]]: List of dictionaries containing
                premise, intermediate, conclusion, and related contexts.
        """
        # Split into sentences (naive period-based segmentation).
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        info_pairs = []
        # Extract combinations of three consecutive sentences.
        for i in range(len(sentences) - 2):
            if len(sentences[i]) > 10 and len(sentences[i + 1]) > 10:
                info_pairs.append(
                    {
                        'premise': sentences[i],
                        'intermediate': sentences[i + 1],
                        # i + 2 is always in range because the loop stops at
                        # len(sentences) - 3, so no bounds guard is needed.
                        'conclusion': sentences[i + 2],
                        'related_contexts': [
                            s
                            for j, s in enumerate(sentences)
                            if j != i and j != i + 1 and len(s) > 10
                        ][:2],
                        # Limit to 2 additional related contexts
                    }
                )
        return info_pairs

    def _generate_qa_pairs(
        self, info_pairs: List[Dict[str, Sequence[str]]]
    ) -> List[Dict[str, str]]:
        r"""Generate multi-hop question-answer pairs from information pairs.

        Args:
            info_pairs (List[Dict[str, Sequence[str]]]): List of information
                pairs extracted from text.

        Returns:
            List[Dict[str, str]]: List of generated QA pairs. Empty when no
                multi-hop agent is configured.
        """
        qa_pairs = []
        for pair in info_pairs:
            # Generate a multi-hop QA pair using the AI agent, if present.
            if self.multi_hop_agent:
                # Construct the full context from the three hops.
                context = (
                    f"{pair['premise']}. {pair['intermediate']}."
                    f" {pair['conclusion']}"
                )
                response = self.multi_hop_agent.generate_multi_hop_qa(context)
                if response:
                    # `model_dump()` replaces `.dict()`, which is deprecated
                    # in pydantic v2 (this codebase targets v2 — see the
                    # `ConfigDict` usage in `ProcessorConfig`).
                    qa_pairs.append(response.value.model_dump())
        return qa_pairs

    def _calculate_complexity(self, qa_pairs: List[Dict[str, Any]]) -> float:
        r"""Calculate the complexity score for a set of QA pairs.

        Args:
            qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate
                complexity for.

        Returns:
            float: Complexity score between 0.0 and 1.0.
        """
        if not qa_pairs:
            return 0.0
        # Calculate complexity as a weighted average of four capped factors.
        complexities = []
        for qa in qa_pairs:
            # 1. Number of reasoning steps
            reasoning_steps_count = len(qa.get('reasoning_steps', []))
            # 2. Number of supporting facts
            supporting_facts_count = len(qa.get('supporting_facts', []))
            # 3. Question length (in words)
            question_length = len(qa.get('question', '').split())
            # 4. Answer length (in words)
            answer_length = len(qa.get('answer', '').split())
            # Weighted sum; each factor is normalized and capped at 1.0.
            qa_complexity = (
                min(reasoning_steps_count / 3, 1.0)
                * 0.4  # Weight for reasoning steps
                + min(supporting_facts_count / 3, 1.0)
                * 0.3  # Weight for supporting facts
                + min(question_length / 20, 1.0)
                * 0.15  # Weight for question length
                + min(answer_length / 50, 1.0) * 0.15
                # Weight for answer length
            )
            complexities.append(qa_complexity)
        return sum(complexities) / len(complexities)
class DataCurator:
    r"""Manages and curates datasets of multi-hop question-answer pairs.

    This class handles dataset management tasks including quality filtering,
    complexity filtering, deduplication, and dataset sampling.

    Attributes:
        config (ProcessorConfig): Configuration for data curation parameters.
        rng (random.Random): Random number generator for reproducible
            sampling.
    """

    def __init__(self, config: ProcessorConfig, rng: random.Random):
        r"""Initialize the DataCurator.

        Args:
            config (ProcessorConfig): Configuration for data curation.
            rng (random.Random): Random number generator for reproducibility.
        """
        self.config = config
        self.rng = rng

    def curate_dataset(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Manage and curate a dataset through multiple filtering stages.

        Args:
            examples (List[Dict[str, Any]]): List of examples to curate.

        Returns:
            List[Dict[str, Any]]: Curated dataset meeting quality criteria.
        """
        logger.info("Starting dataset management...")
        # 1. Quality filtering
        quality_filtered = self._quality_filter(examples)
        logger.info(
            f"Remaining examples after quality filtering:"
            f" {len(quality_filtered)}"
        )
        # 2. Complexity filtering
        complexity_filtered = self._complexity_filter(quality_filtered)
        logger.info(
            f"Remaining examples after complexity filtering:"
            f" {len(complexity_filtered)}"
        )
        # 3. Deduplication
        deduplicated = self._remove_duplicates(complexity_filtered)
        logger.info(
            f"Remaining examples after deduplication: {len(deduplicated)}"
        )
        # 4. Sample to target size
        final_dataset = self._sample_dataset(deduplicated)
        logger.info(f"Final dataset size: {len(final_dataset)}")
        return final_dataset

    def _quality_filter(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Filter examples based on quality criteria.

        Args:
            examples (List[Dict[str, Any]]): List of examples to filter.

        Returns:
            List[Dict[str, Any]]: Examples that pass quality checks.
        """
        filtered = []
        for example in examples:
            # 1. Check QA pair quality
            qa_quality = self._check_qa_quality(example.get('qa_pairs', []))
            # 2. Check text quality: at least 20 words
            text_quality = (
                len(example.get('text', '').split()) >= 20
            )
            if qa_quality and text_quality:
                filtered.append(example)
        return filtered

    def _check_qa_quality(self, qa_pairs: List[Dict[str, str]]) -> bool:
        r"""Check the quality of question-answer pairs.

        Args:
            qa_pairs (List[Dict[str, str]]): List of QA pairs to check.

        Returns:
            bool: True if QA pairs meet quality criteria, False otherwise.
        """
        if not qa_pairs:
            return False
        for qa in qa_pairs:
            # 1. Length check: question >= 10 chars, answer >= 5 chars
            if (
                len(qa.get('question', '')) < 10
                or len(qa.get('answer', '')) < 5
            ):
                return False
            # 2. QA pair duplication check: question must differ from answer
            if qa.get('question', '') == qa.get('answer', ''):
                return False
        return True

    def _complexity_filter(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Filter examples based on complexity threshold.

        Removes examples with complexity scores below the configured
        threshold.

        Args:
            examples (List[Dict[str, Any]]): List of examples to filter.

        Returns:
            List[Dict[str, Any]]: Examples meeting complexity threshold.
        """
        return [
            example
            for example in examples
            if example.get('metadata', {}).get('complexity', 0)
            >= self.config.complexity_threshold
        ]

    def _remove_duplicates(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Remove duplicate examples from the dataset.

        Args:
            examples (List[Dict[str, Any]]): List of examples to deduplicate.

        Returns:
            List[Dict[str, Any]]: Deduplicated examples, keeping the first
                occurrence of each (text, qa_pairs) combination.
        """
        seen = set()
        unique_examples = []
        for example in examples:
            # Key on the actual (text, qa_pairs) content rather than
            # `hash()` of their concatenation: using the tuple avoids both
            # hash collisions and concatenation ambiguity ("ab"+"c" vs
            # "a"+"bc") silently dropping distinct examples.
            identifier = (
                example.get('text', ''),
                str(example.get('qa_pairs', [])),
            )
            if identifier not in seen:
                seen.add(identifier)
                unique_examples.append(example)
        return unique_examples

    def _sample_dataset(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Sample examples to match target dataset size.

        Args:
            examples (List[Dict[str, Any]]): List of examples to sample from.

        Returns:
            List[Dict[str, Any]]: Sampled dataset of target size or smaller.
        """
        if len(examples) <= self.config.dataset_size:
            return examples
        # Seeded sampling keeps the curated subset reproducible.
        return self.rng.sample(examples, self.config.dataset_size)

View File

@@ -0,0 +1,93 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, ClassVar, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
class ReasoningStep(BaseModel):
    r"""A single step in a multi-hop reasoning process.

    Attributes:
        step (str): The textual description of the reasoning step.
    """

    # Required field: `...` marks it as having no default.
    step: str = Field(
        ..., description="A single step in the reasoning process."
    )
class MultiHopQA(BaseModel):
    r"""A multi-hop question-answer pair with reasoning steps and supporting
    facts.

    Attributes:
        question (str): The question requiring multi-hop reasoning.
        reasoning_steps (List[ReasoningStep]): List of reasoning steps to
            answer.
        answer (str): The final answer to the question.
        supporting_facts (List[str]): List of facts supporting the reasoning.
        type (str): The type of question-answer pair.
    """

    # Pydantic v2 configuration. The class-based `Config` style is
    # deprecated in pydantic v2; `ConfigDict` also matches how
    # `ProcessorConfig` is configured elsewhere in this package.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "question": "What is the capital of France?",
                "reasoning_steps": [
                    {"step": "Identify the country France."},
                    {"step": "Find the capital city of France."},
                ],
                "answer": "Paris",
                "supporting_facts": [
                    "France is a country in Europe.",
                    "Paris is the capital city of France.",
                ],
                "type": "multi_hop_qa",
            }
        }
    )

    question: str = Field(
        ..., description="The question that requires multi-hop reasoning."
    )
    reasoning_steps: List[ReasoningStep] = Field(
        ...,
        description="The steps involved in reasoning to answer the question.",
    )
    answer: str = Field(
        ..., description="The answer to the multi-hop question."
    )
    supporting_facts: List[str] = Field(
        ..., description="Facts that support the reasoning and answer."
    )
    type: str = Field(description="The type of question-answer pair.")
class ContextPrompt(BaseModel):
    r"""A context prompt for generating multi-hop question-answer pairs.

    Attributes:
        main_context (str): The primary context for generating QA pairs.
        related_contexts (Optional[List[str]]): Additional related contexts.
    """

    # Required: the text the QA pair is primarily generated from.
    main_context: str = Field(
        ...,
        description="The main context for generating"
        " the question-answer pair.",
    )
    # Optional supplementary passages that support multi-hop links.
    related_contexts: Optional[List[str]] = Field(
        default=None,
        description="Additional contexts related to the main context.",
    )

View File

@@ -0,0 +1,74 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
from pydantic import BaseModel, ConfigDict, Field
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent
class ProcessorConfig(BaseModel):
    r"""Configuration options for source2synth data processing.

    Controls text-length bounds, complexity filtering, target dataset size,
    and whether an AI agent is used for multi-hop QA generation.
    """

    model_config = ConfigDict(
        validate_assignment=True,
        frozen=False,
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )

    seed: int = Field(  # Generate a random seed for reproducibility
        default_factory=lambda: random.randint(0, 1000),
        description="Random seed for reproducibility",
    )
    min_length: int = Field(
        default=50, description="Minimum text length", ge=0
    )
    max_length: int = Field(
        default=512, description="Maximum text length", gt=0
    )
    complexity_threshold: float = Field(
        default=0.5,
        description="Complexity threshold for processing",
        ge=0.0,
        le=1.0,
    )
    dataset_size: int = Field(
        default=1000, description="Target size of the dataset", gt=0
    )
    use_ai_model: bool = Field(
        default=True, description="Whether to use AI model in processing"
    )
    hop_generating_agent: MultiHopGeneratorAgent = Field(
        default_factory=lambda: MultiHopGeneratorAgent(),
        description="Agent for generating multi-hop text",
    )

    def __repr__(self) -> str:
        # Compact summary of the scalar settings (the agent is omitted).
        return (
            f"ProcessorConfig("
            f"seed={self.seed}, min_length={self.min_length}, "
            f"max_length={self.max_length}, "
            f"complexity_threshold={self.complexity_threshold}, "
            f"dataset_size={self.dataset_size}, "
            f"use_ai_model={self.use_ai_model}"
            f")"
        )

View File

@@ -0,0 +1,23 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseDatasetManager
from .huggingface import HuggingFaceDatasetManager
from .models import Record
# Public API of the datahubs package.
__all__ = [
    "BaseDatasetManager",
    "Record",
    "HuggingFaceDatasetManager",
]

136
camel/datahubs/base.py Normal file
View File

@@ -0,0 +1,136 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any, List
from camel.datahubs.models import Record
class BaseDatasetManager(ABC):
    r"""Abstract base class for dataset managers.

    Defines the dataset- and record-level operations that concrete backends
    (e.g. a Hugging Face Hub implementation) must provide.
    """

    @abstractmethod
    def create_dataset(self, name: str, **kwargs: Any) -> str:
        r"""Creates a new dataset.

        Args:
            name (str): The name of the dataset.
            kwargs (Any): Additional keyword arguments.

        Returns:
            str: The URL of the created dataset.
        """
        pass

    @abstractmethod
    def list_datasets(
        self, username: str, limit: int = 100, **kwargs: Any
    ) -> List[str]:
        r"""Lists all datasets for the current user.

        Args:
            username (str): The username of the user whose datasets to list.
            limit (int): The maximum number of datasets to list.
                (default: :obj:`100`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[str]: A list of dataset ids.
        """
        pass

    @abstractmethod
    def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
        r"""Deletes a dataset.

        Args:
            dataset_name (str): The name of the dataset to delete.
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def add_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Adds records to a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to add to the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def update_records(
        self,
        dataset_name: str,
        records: List[Record],
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Updates records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            records (List[Record]): A list of records to update in the
                dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

    @abstractmethod
    def list_records(
        self,
        dataset_name: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> List[Record]:
        r"""Lists records in a dataset.

        Args:
            dataset_name (str): The name of the dataset.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.

        Returns:
            List[Record]: A list of records in the dataset.
        """
        pass

    @abstractmethod
    def delete_record(
        self,
        dataset_name: str,
        record_id: str,
        filepath: str = "records/records.json",
        **kwargs: Any,
    ) -> None:
        r"""Deletes a record from the dataset.

        Args:
            dataset_name (str): The name of the dataset.
            record_id (str): The ID of the record to delete.
            filepath (str): The path to the file containing the records.
                (default: :obj:`"records/records.json"`)
            kwargs (Any): Additional keyword arguments.
        """
        pass

View File

@@ -0,0 +1,444 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import tempfile
from typing import Any, List, Optional
from camel.datahubs.base import BaseDatasetManager
from camel.datahubs.models import Record
from camel.logger import get_logger
from camel.types import HuggingFaceRepoType
from camel.utils import api_keys_required, dependencies_required
logger = get_logger(__name__)
class HuggingFaceDatasetManager(BaseDatasetManager):
r"""A dataset manager for Hugging Face datasets. This class provides
methods to create, add, update, delete, and list records in a dataset on
the Hugging Face Hub.
Args:
token (str): The Hugging Face API token. If not provided, the token
will be read from the environment variable `HF_TOKEN`.
"""
@api_keys_required(
    [
        ("token", "HF_TOKEN"),
    ]
)
@dependencies_required('huggingface_hub')
def __init__(self, token: Optional[str] = None):
    r"""Initializes the manager with an authenticated Hub client.

    Args:
        token (Optional[str]): The Hugging Face API token. Falls back to
            the ``HF_TOKEN`` environment variable when not provided.
    """
    # Deferred import: `dependencies_required` guarantees availability.
    from huggingface_hub import HfApi

    self._api_key = token or os.getenv("HF_TOKEN")
    self.api = HfApi(token=self._api_key)
def create_dataset_card(
    self,
    dataset_name: str,
    description: str,
    license: Optional[str] = None,
    version: Optional[str] = None,
    tags: Optional[List[str]] = None,
    authors: Optional[List[str]] = None,
    size_category: Optional[List[str]] = None,
    language: Optional[List[str]] = None,
    task_categories: Optional[List[str]] = None,
    content: Optional[str] = None,
) -> None:
    r"""Creates and uploads a dataset card to the Hugging Face Hub in YAML
    format.

    Args:
        dataset_name (str): The name of the dataset.
        description (str): A description of the dataset.
        license (str): The license of the dataset. (default: :obj:`None`)
        version (str): The version of the dataset. (default: :obj:`None`)
        tags (list): A list of tags for the dataset. (default: :obj:`None`)
        authors (list): A list of authors of the dataset. (default:
            :obj:`None`)
        size_category (list): A size category for the dataset. (default:
            :obj:`None`)
        language (list): A list of languages the dataset is in. (default:
            :obj:`None`)
        task_categories (list): A list of task categories. (default:
            :obj:`None`)
        content (str): Custom markdown content that the user wants to add
            to the dataset card. (default: :obj:`None`)
    """
    import yaml

    # Assemble the card's YAML front-matter, dropping any field the
    # caller left unset (falsy values are omitted to keep the card clean).
    raw_fields = {
        "license": license,
        "authors": authors,
        "task_categories": task_categories,
        "language": language,
        "tags": tags,
        "pretty_name": dataset_name,
        "size_categories": size_category,
        "version": version,
        "description": description,
    }
    metadata = {key: value for key, value in raw_fields.items() if value}

    yaml_block = yaml.dump(
        metadata, default_flow_style=False, allow_unicode=True
    )
    card_content = "---\n" + yaml_block + "\n---"
    if content:
        # Append caller-supplied markdown under a dedicated section.
        card_content += f"\n\n# Additional Information\n{content}\n"

    self._upload_file(
        file_content=card_content,
        dataset_name=dataset_name,
        filepath="README.md",
        file_type="md",
    )
def create_dataset(
    self, name: str, private: bool = False, **kwargs: Any
) -> str:
    r"""Creates a new dataset on the Hugging Face Hub.

    The dataset repository is only created when it does not already exist;
    an existing repository is left untouched.

    Args:
        name (str): The name of the dataset.
        private (bool): Whether the dataset should be private.
            (default: :obj:`False`)
        kwargs (Any): Additional keyword arguments.

    Returns:
        str: The URL of the created dataset.
    """
    from huggingface_hub.errors import RepositoryNotFoundError

    repo_type = HuggingFaceRepoType.DATASET.value
    try:
        # Probe for an existing repository; a miss raises and we create it.
        self.api.repo_info(repo_id=name, repo_type=repo_type, **kwargs)
    except RepositoryNotFoundError:
        self.api.create_repo(
            repo_id=name,
            repo_type=repo_type,
            private=private,
        )
    return f"https://huggingface.co/datasets/{name}"
def list_datasets(
    self, username: str, limit: int = 100, **kwargs: Any
) -> List[str]:
    r"""Lists all datasets for the current user.

    Args:
        username (str): The username of the user whose datasets to list.
        limit (int): The maximum number of datasets to list.
            (default: :obj:`100`)
        kwargs (Any): Additional keyword arguments.

    Returns:
        List[str]: A list of dataset ids; empty if the Hub query fails.
    """
    try:
        matches = self.api.list_datasets(
            author=username, limit=limit, **kwargs
        )
        return [dataset.id for dataset in matches]
    except Exception as e:
        # Best-effort listing: log the failure and report no datasets.
        logger.error(f"Error listing datasets: {e}")
        return []
def delete_dataset(self, dataset_name: str, **kwargs: Any) -> None:
    r"""Deletes a dataset from the Hugging Face Hub.

    Args:
        dataset_name (str): The name of the dataset to delete.
        kwargs (Any): Additional keyword arguments.

    Raises:
        Exception: Re-raises any error from the Hub client after logging.
    """
    try:
        self.api.delete_repo(
            repo_id=dataset_name,
            repo_type=HuggingFaceRepoType.DATASET.value,
            **kwargs,
        )
        logger.info(f"Dataset '{dataset_name}' deleted successfully.")
    except Exception as e:
        # Log for diagnostics, then propagate so callers can react.
        logger.error(f"Error deleting dataset '{dataset_name}': {e}")
        raise
def add_records(
    self,
    dataset_name: str,
    records: List[Record],
    filepath: str = "records/records.json",
    **kwargs: Any,
) -> None:
    r"""Adds records to a dataset on the Hugging Face Hub.

    Args:
        dataset_name (str): The name of the dataset.
        records (List[Record]): A list of records to add to the dataset.
        filepath (str): The path to the file containing the records.
            (default: :obj:`"records/records.json"`)
        kwargs (Any): Additional keyword arguments.

    Raises:
        ValueError: If the dataset already has a records file at
            ``filepath``.
    """
    existing_records = self._download_records(
        dataset_name=dataset_name, filepath=filepath, **kwargs
    )
    if existing_records:
        # Fixed message: the previous wording labelled the records file
        # path as "Dataset '...'", which was misleading to act on.
        raise ValueError(
            f"Records file '{filepath}' in dataset '{dataset_name}' "
            f"already exists. Use `update_records` to modify."
        )
    self._upload_records(
        records=records,
        dataset_name=dataset_name,
        filepath=filepath,
        **kwargs,
    )
def update_records(
self,
dataset_name: str,
records: List[Record],
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Updates records in a dataset on the Hugging Face Hub.
Args:
dataset_name (str): The name of the dataset.
records (List[Record]): A list of records to update in the dataset.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If the dataset does not have an existing file to update
records in.
"""
existing_records = self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
if not existing_records:
logger.warning(
f"Dataset '{dataset_name}' does not have existing "
"records. Adding new records."
)
self._upload_records(
records=records,
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
return
old_dict = {record.id: record for record in existing_records}
new_dict = {record.id: record for record in records}
merged_dict = old_dict.copy()
merged_dict.update(new_dict)
self._upload_records(
records=list(merged_dict.values()),
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
def delete_record(
self,
dataset_name: str,
record_id: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> None:
r"""Deletes a record from the dataset.
Args:
dataset_name (str): The name of the dataset.
record_id (str): The ID of the record to delete.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Raises:
ValueError: If the dataset does not have an existing file to delete
records from.
"""
existing_records = self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
if not existing_records:
raise ValueError(
f"Dataset '{dataset_name}' does not have an existing file to "
f"delete records from."
)
filtered_records = [
record for record in existing_records if record.id != record_id
]
self._upload_records(
records=filtered_records,
dataset_name=dataset_name,
filepath=filepath,
**kwargs,
)
def list_records(
self,
dataset_name: str,
filepath: str = "records/records.json",
**kwargs: Any,
) -> List[Record]:
r"""Lists all records in a dataset.
Args:
dataset_name (str): The name of the dataset.
filepath (str): The path to the file containing the records.
kwargs (Any): Additional keyword arguments.
Returns:
List[Record]: A list of records in the dataset.
"""
return self._download_records(
dataset_name=dataset_name, filepath=filepath, **kwargs
)
def _download_records(
self, dataset_name: str, filepath: str, **kwargs: Any
) -> List[Record]:
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
try:
downloaded_file_path = hf_hub_download(
repo_id=dataset_name,
filename=filepath,
repo_type=HuggingFaceRepoType.DATASET.value,
token=self._api_key,
**kwargs,
)
with open(downloaded_file_path, "r") as f:
records_data = json.load(f)
return [Record(**record) for record in records_data]
except EntryNotFoundError:
logger.info(f"No records found for dataset '{dataset_name}'.")
return []
except Exception as e:
logger.error(f"Error downloading or processing records: {e}")
raise e
def _upload_records(
self,
records: List[Record],
dataset_name: str,
filepath: str,
**kwargs: Any,
):
with tempfile.NamedTemporaryFile(
delete=False, mode="w", newline="", encoding="utf-8"
) as f:
json.dump(
[
record.model_dump(exclude_defaults=True)
for record in records
],
f,
ensure_ascii=False,
)
temp_file_path = f.name
try:
self.api.upload_file(
path_or_fileobj=temp_file_path,
path_in_repo=filepath,
repo_id=dataset_name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
except Exception as e:
logger.error(f"Error uploading records file: {e}")
raise
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
def _upload_file(
self,
file_content: str,
dataset_name: str,
filepath: str,
file_type: str = "json",
**kwargs: Any,
):
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=f".{file_type}"
) as f:
if file_type == "json":
if isinstance(file_content, str):
try:
json_content = json.loads(file_content)
except json.JSONDecodeError:
raise ValueError(
"Invalid JSON string provided for file_content."
)
else:
try:
json.dumps(file_content, ensure_ascii=False)
json_content = file_content
except (TypeError, ValueError):
raise ValueError(
"file_content is not JSON serializable."
)
json.dump(json_content, f, ensure_ascii=False)
elif file_type == "md" or file_type == "txt":
f.write(file_content)
else:
raise ValueError(f"Unsupported file type: {file_type}")
temp_file_path = f.name
try:
self.api.upload_file(
path_or_fileobj=temp_file_path,
path_in_repo=filepath,
repo_id=dataset_name,
repo_type=HuggingFaceRepoType.DATASET.value,
**kwargs,
)
logger.info(f"File uploaded successfully: {filepath}")
except Exception as e:
logger.error(f"Error uploading file: {e}")
raise
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

24
camel/datahubs/models.py Normal file
View File

@@ -0,0 +1,24 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict
class Record(BaseModel):
    r"""A single record stored in a Hugging Face dataset repo.

    All fields are optional so partially specified records can be
    round-tripped through JSON.
    """

    # Unique identifier of the record within the dataset.
    id: Optional[str] = None
    # Arbitrary auxiliary information attached to the record.
    metadata: Optional[Dict[str, Any]] = None
    # The record payload itself.
    content: Optional[Dict[str, Any]] = None
    # Accept (and keep) fields beyond the ones declared above.
    model_config = ConfigDict(extra="allow")

View File

@@ -0,0 +1,26 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base_generator import BaseGenerator
from .few_shot_generator import FewShotGenerator
from .models import DataPoint
from .self_instruct_generator import SelfInstructGenerator
from .static_dataset import StaticDataset
# Public API of the camel.datasets package (what `from ... import *` exports).
__all__ = [
    "BaseGenerator",
    "DataPoint",
    "FewShotGenerator",
    "StaticDataset",
    "SelfInstructGenerator",
]

View File

@@ -0,0 +1,292 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import abc
import asyncio
import json
import random
from pathlib import Path
from typing import Any, Dict, List, Union
from pydantic import ValidationError
from torch.utils.data import IterableDataset
from camel.logger import get_logger
from .models import DataPoint
logger = get_logger(__name__)
class BaseGenerator(abc.ABC, IterableDataset):
    r"""Abstract base class for data generators.

    This class defines the interface for generating synthetic datapoints.
    Concrete implementations should provide specific generation strategies.
    """

    def __init__(
        self,
        seed: int = 42,
        buffer: int = 20,
        cache: Union[str, Path, None] = None,
        data_path: Union[str, Path, None] = None,
        **kwargs,
    ):
        r"""Initialize the base generator.

        Args:
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            buffer (int): Amount of DataPoints to be generated when the
                iterator runs out of DataPoints in data. (default: :obj:`20`)
            cache (Union[str, Path, None]): Optional path to save generated
                datapoints during iteration. If None is provided, datapoints
                will be discarded every 100 generations.
            data_path (Union[str, Path, None]): Optional path to a JSONL file
                to initialize the dataset from.
            **kwargs: Additional generator parameters.
        """
        self._rng = random.Random(seed)
        self.cache = Path(cache) if cache else None
        self._buffer = buffer
        self._data: List[DataPoint] = []
        # Datapoints yielded since the last flush; persisted (or discarded)
        # in batches of 100 by `_flush_batch`.
        self._batch_to_save: List[DataPoint] = []
        if data_path:
            file_path = Path(data_path)
            raw_data = self._init_from_jsonl(file_path)
            try:
                data_points = [DataPoint(**item) for item in raw_data]
                self._data.extend(data_points)
            except ValidationError as e:
                raise ValueError(
                    f"Failed to create DataPoint from JSONL data: {e}"
                )

    @abc.abstractmethod
    async def generate_new(self, n: int, **kwargs) -> None:
        r"""Generate n new datapoints and append them to self._data.

        Subclass implementations must generate the specified number of
        datapoints and append them directly to the `self._data` list.
        This method should not return the datapoints; the iterator
        relies on `self._data` being populated.

        Args:
            n (int): Number of datapoints to generate and append.
            **kwargs: Additional generation parameters.

        Returns:
            None: This method should not return anything.

        Example:
            ```python
            async def generate_new(self, n: int, **kwargs) -> None:
                new_points = [DataPoint(...) for _ in range(n)]
                self._data.extend(new_points)
            ```
        """
        pass

    @staticmethod
    def _raise_if_in_event_loop(message: str) -> None:
        r"""Raise ``RuntimeError(message)`` when called from a running
        asyncio event loop.

        Uses :func:`asyncio.get_running_loop`, the supported API for this
        check; ``asyncio.get_event_loop`` is deprecated here and its error
        message differs across Python versions (e.g. "There is no current
        event loop"), which broke the original string-matching guard.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: synchronous use is safe.
            return
        raise RuntimeError(message)

    def _flush_batch(self) -> None:
        r"""Persist (or discard) the accumulated batch once it reaches 100
        datapoints.

        Appends each datapoint as one JSON line to `self.cache` when a cache
        path is configured; otherwise the batch is simply dropped.
        """
        if len(self._batch_to_save) < 100:
            return
        if self.cache:
            with self.cache.open("a", encoding="utf-8") as f:
                for dp in self._batch_to_save:
                    json.dump(dp.to_dict(), f, ensure_ascii=False)
                    f.write("\n")
        self._batch_to_save = []

    def __aiter__(self):
        r"""Async iterator that yields datapoints dynamically.

        If a `data_path` was provided during initialization, those datapoints
        are yielded first. When self._data is empty, `buffer` new datapoints
        are generated. Every 100 yields, the batch is appended to the
        JSONL file or discarded if `cache` is None.

        Yields:
            DataPoint: A single datapoint.
        """

        async def generator():
            while True:
                if not self._data:
                    await self.generate_new(self._buffer)
                datapoint = self._data.pop(0)
                yield datapoint
                self._batch_to_save.append(datapoint)
                self._flush_batch()

        return generator()

    def __iter__(self):
        r"""Synchronous iterator for PyTorch IterableDataset compatibility.

        If a `data_path` was provided during initialization, those datapoints
        are yielded first. When self._data is empty, `buffer` new datapoints
        are generated. Every 100 yields, the batch is appended to the
        JSONL file or discarded if `cache` is None.

        Yields:
            DataPoint: A single datapoint.
        """
        self._raise_if_in_event_loop(
            "Cannot use synchronous iteration (__iter__) in an async "
            "context; use 'async for' with __aiter__ instead"
        )
        while True:
            if not self._data:
                # Drive the async generation hook from synchronous code.
                asyncio.run(self.generate_new(self._buffer))
            datapoint = self._data.pop(0)
            yield datapoint
            self._batch_to_save.append(datapoint)
            self._flush_batch()

    def sample(self) -> DataPoint:
        r"""Returns the next datapoint from the current dataset
        synchronously.

        Raises:
            RuntimeError: If called in an async context.

        Returns:
            DataPoint: The next DataPoint.

        Note:
            This method is intended for synchronous contexts.
            Use 'async_sample' in asynchronous contexts to
            avoid blocking or runtime errors.
        """
        self._raise_if_in_event_loop(
            "Cannot use synchronous sampling (sample) "
            "in an async context; use async_sample instead"
        )
        return next(iter(self))

    async def async_sample(self) -> DataPoint:
        r"""Returns the next datapoint from the current dataset asynchronously.

        Returns:
            DataPoint: The next datapoint.

        Note:
            This method is intended for asynchronous contexts. Use 'sample'
            in synchronous contexts.
        """
        async_iter = self.__aiter__()
        return await async_iter.__anext__()

    def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
        r"""Saves the generated datapoints to a JSONL (JSON Lines) file.

        Each datapoint is stored as a separate JSON object on a new line.

        Args:
            file_path (Union[str, Path]): Path to save the JSONL file.

        Raises:
            ValueError: If no datapoints have been generated.
            IOError: If there is an issue writing to the file.

        Notes:
            - Uses `self._data`, which contains the generated datapoints.
            - Appends to the file if it already exists.
            - Ensures compatibility with large datasets by using JSONL format.
        """
        if not self._data:
            raise ValueError("Dataset is empty. No data to save.")
        file_path = Path(file_path)
        try:
            with file_path.open("a", encoding="utf-8") as f:
                for datapoint in self._data:
                    json.dump(datapoint.to_dict(), f, ensure_ascii=False)
                    f.write("\n")
            logger.info(f"Dataset saved successfully to {file_path}")
        except IOError as e:
            logger.error(f"Error writing to file {file_path}: {e}")
            raise

    def flush(self, file_path: Union[str, Path]) -> None:
        r"""Flush the current data to a JSONL file and clear the data.

        Args:
            file_path (Union[str, Path]): Path to save the JSONL file.

        Notes:
            - Uses `save_to_jsonl` to save `self._data`.
        """
        self.save_to_jsonl(file_path)
        self._data = []
        logger.info(f"Data flushed to {file_path} and cleared from the memory")

    def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
        r"""Load and parse a dataset from a JSONL file.

        Args:
            file_path (Path): Path to the JSONL file.

        Returns:
            List[Dict[str, Any]]: A list of datapoint dictionaries.

        Raises:
            FileNotFoundError: If the specified JSONL file does not exist.
            ValueError: If a line contains invalid JSON or is not a dictionary.
        """
        if not file_path.exists():
            raise FileNotFoundError(f"JSONL file not found: {file_path}")
        raw_data = []
        logger.debug(f"Loading JSONL from {file_path}")
        with file_path.open('r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # Skip blank lines
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"Invalid JSON on line {line_number} "
                        f"in file {file_path}: {e}"
                    )
                if not isinstance(record, dict):
                    raise ValueError(
                        f"Expected a dictionary at line {line_number}, "
                        f"got {type(record).__name__}"
                    )
                raw_data.append(record)
        logger.info(
            f"Successfully loaded {len(raw_data)} items from {file_path}"
        )
        return raw_data

View File

@@ -0,0 +1,282 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
from datetime import datetime
from typing import List
from pydantic import BaseModel, Field, ValidationError
from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.models.base_model import BaseModelBackend
from camel.verifiers import BaseVerifier
from .base_generator import BaseGenerator
from .models import DataPoint
from .static_dataset import StaticDataset
logger = get_logger(__name__)
# System message for the generating agent built in FewShotGenerator.__init__
# (`ChatAgent(system_message=SYSTEM_PROMPT, model=model)`). The described
# Question/Rationale/Final Answer structure mirrors the fields of the
# structured response model used in `generate_new`.
SYSTEM_PROMPT = """**You are an advanced data generation assistant.**
Your goal is to generate high-quality synthetic data points based on
provided examples. Your output must be well-structured,
logically sound, and formatted correctly.
**Instructions:**
1. **Follow the Structure**
Each data point must include:
- **Question**: A clear, well-formed query.
- **Rationale**: A step-by-step, executable reasoning process ending
with `print(final_answer)`.
- **Final Answer**: The correct, concise result.
2. **Ensure Logical Consistency**
- The `rationale` must be code that runs correctly.
- The `final_answer` should match the printed output.
3. **Output Format (Strict)**
```
Question: [Generated question]
Rationale: [Code that solves the question, ending in a print statement,
outputting the answer.]
Final Answer: [The Final Answer]
**Now, generate a new data point based on the given examples.**
"""
class FewShotGenerator(BaseGenerator):
    r"""A generator for creating synthetic datapoints using few-shot learning.

    This class leverages a seed dataset, an agent, and a verifier to generate
    new synthetic datapoints on demand through few-shot prompting.
    """

    def __init__(
        self,
        seed_dataset: StaticDataset,
        verifier: BaseVerifier,
        model: BaseModelBackend,
        seed: int = 42,
        **kwargs,
    ):
        r"""Initialize the few-shot generator.

        Args:
            seed_dataset (StaticDataset): Validated static dataset to
                use for examples.
            verifier (BaseVerifier): Verifier to validate generated content.
            model (BaseModelBackend): The underlying LLM that the generating
                agent will be initiated with.
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            **kwargs: Additional generator parameters.
        """
        super().__init__(seed=seed, **kwargs)
        self.seed_dataset = seed_dataset
        try:
            self._validate_seed_dataset()
        except Exception as e:
            # Chain the original error so the root cause stays visible.
            raise RuntimeError(
                "Seed Data does not follow Datapoint format"
            ) from e
        self.verifier = verifier
        self.agent = ChatAgent(system_message=SYSTEM_PROMPT, model=model)
        # Instance-level lock so concurrent generate_new() calls serialize
        # their updates to self._data. (A lock constructed inside
        # generate_new would be private to that call and guard nothing.)
        self._lock = asyncio.Lock()

    # TODO: Validate that seed dataset contains rationale
    def _validate_seed_dataset(self) -> None:
        pass

    def _construct_prompt(self, examples: List[DataPoint]) -> str:
        r"""Construct a prompt for generating new datapoints
        using a fixed sample of examples from the seed dataset.

        Args:
            examples (List[DataPoint]): Examples to include in the prompt.

        Returns:
            str: Formatted prompt with examples.
        """
        prompt = (
            "Generate a new datapoint similar to the following examples:\n\n"
        )
        for i, example in enumerate(examples, 1):
            prompt += f"Example {i}:\n"
            prompt += f"Question: {example.question}\n"
            if example.rationale is not None:
                prompt += f"Rationale: {example.rationale}\n"
            else:
                prompt += "Rationale: None\n"
            prompt += f"Final Answer: {example.final_answer}\n\n"
        prompt += "New datapoint:"
        return prompt

    async def generate_new(
        self,
        n: int,
        max_retries: int = 10,
        num_examples: int = 3,
        **kwargs,
    ) -> None:
        r"""Generates and validates `n` new datapoints through
        few-shot prompting, with a retry limit.

        Steps:
            1. Samples examples from the seed dataset.
            2. Constructs a prompt using the selected examples.
            3. Uses an agent to generate a new datapoint,
               consisting of a question and code to solve the question.
            4. Executes code using a verifier to get pseudo ground truth.
            5. Stores valid datapoints in memory.

        Args:
            n (int): Number of valid datapoints to generate.
            max_retries (int): Maximum number of retries before stopping.
                (default: :obj:`10`)
            num_examples (int): Number of examples to sample from the
                seed dataset for few shot prompting. (default: :obj:`3`)
            **kwargs: Additional generation parameters.

        Raises:
            TypeError: If the agent's output is not a dictionary (or does not
                match the expected format).
            KeyError: If required keys are missing from the response.
            AttributeError: If the verifier response lacks attributes.
            ValidationError: If a datapoint fails schema validation.
            RuntimeError: If retries are exhausted before `n` valid datapoints
                are generated.

        Notes:
            - Retries on validation failures until `n` valid datapoints exist
              or `max_retries` is reached, whichever comes first.
            - If retries are exhausted before reaching `n`, a `RuntimeError`
              is raised.
            - Metadata includes a timestamp for tracking datapoint creation.
        """

        # Simplified version of DataPoint that omits metadata because
        # agent.step's response_format parameter doesn't support type
        # Dict[str, Any]. Defined once here rather than on every retry
        # iteration of the loop below.
        class DataPointSimplified(BaseModel):
            question: str = Field(
                description="The primary question or issue to "
                "be addressed."
            )
            final_answer: str = Field(description="The final answer.")
            rationale: str = Field(
                description="Logical reasoning or explanation "
                "behind the answer."
            )

        valid_data_points: List[DataPoint] = []
        retries = 0

        while len(valid_data_points) < n and retries < max_retries:
            try:
                examples = [
                    self.seed_dataset.sample() for _ in range(num_examples)
                ]
                prompt = self._construct_prompt(examples)

                try:
                    agent_output = (
                        self.agent.step(
                            prompt, response_format=DataPointSimplified
                        )
                        .msgs[0]
                        .parsed
                    )
                    if not isinstance(agent_output, DataPointSimplified):
                        # Explicit check instead of `assert`, which would be
                        # stripped under `python -O`.
                        raise TypeError(
                            "Agent did not return a structured datapoint"
                        )
                    self.agent.reset()
                except (TypeError, KeyError) as e:
                    logger.warning(
                        f"Agent output issue: {e}, retrying... "
                        f"({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                rationale = agent_output.rationale
                if not isinstance(rationale, str):
                    raise TypeError(f"Rationale {rationale} is not a string.")

                try:
                    verifier_response = await asyncio.wait_for(
                        self.verifier.verify(
                            solution=rationale,
                            reference_answer=None,
                        ),
                        timeout=180,
                    )
                    if not verifier_response or not verifier_response.result:
                        raise ValueError(
                            "Verifier unsuccessful, response: "
                            f"{verifier_response}"
                        )
                except (ValueError, AttributeError, asyncio.TimeoutError) as e:
                    error_msg = (
                        "Verifier timeout"
                        if isinstance(e, asyncio.TimeoutError)
                        else f"Verifier issue: {e}"
                    )
                    logger.warning(
                        f"{error_msg}, retrying... "
                        f"({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                try:
                    new_datapoint = DataPoint(
                        question=agent_output.question,
                        rationale=rationale,
                        final_answer=verifier_response.result,
                        metadata={
                            "synthetic": str(True),
                            "created": datetime.now().isoformat(),
                            "generator": "few_shot",
                            "shots": [e.to_dict() for e in examples],
                        },
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Datapoint validation failed: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                valid_data_points.append(new_datapoint)

            except Exception as e:
                logger.warning(
                    f"Unexpected error: {e}, retrying..."
                    f" ({retries + 1}/{max_retries})"
                )
                retries += 1

        if len(valid_data_points) < n:
            raise RuntimeError(
                f"Failed to generate {n} valid datapoints "
                f"after {max_retries} retries."
            )

        # Serialize concurrent extensions of the shared data list using the
        # instance lock created in __init__. (The original created a fresh
        # `asyncio.Lock()` here on every call, which provided no mutual
        # exclusion at all.)
        async with self._lock:
            self._data.extend(valid_data_points)

61
camel/datasets/models.py Normal file
View File

@@ -0,0 +1,61 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class DataPoint(BaseModel):
    r"""A single data point in the dataset.

    Attributes:
        question (str): The primary question or issue to be addressed.
        final_answer (str): The final answer.
        rationale (Optional[str]): Logical reasoning or explanation behind the
            answer. (default: :obj:`None`)
        metadata (Optional[Dict[str, Any]]): Additional metadata about the data
            point. (default: :obj:`None`)
    """

    question: str = Field(
        ..., description="The primary question or issue to be addressed."
    )
    final_answer: str = Field(..., description="The final answer.")
    rationale: Optional[str] = Field(
        default=None,
        description="Logical reasoning or explanation behind the answer.",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional metadata about the data point."
    )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert DataPoint to a dictionary.

        Returns:
            Dict[str, Any]: Dictionary representation of the DataPoint.
        """
        # `model_dump` is the pydantic v2 API used elsewhere in this
        # codebase; `BaseModel.dict()` is deprecated in v2 and emits a
        # DeprecationWarning.
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
        r"""Create a DataPoint from a dictionary.

        Args:
            data (Dict[str, Any]): Dictionary containing DataPoint fields.

        Returns:
            DataPoint: New DataPoint instance.
        """
        return cls(**data)

View File

@@ -0,0 +1,415 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
import random
from datetime import datetime
from typing import Iterable, List, Optional, cast
from pydantic import BaseModel, Field, ValidationError
from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from camel.verifiers import BaseVerifier
from .base_generator import BaseGenerator
from .models import DataPoint
from .static_dataset import StaticDataset
logger = get_logger(__name__)
# System message for the instruction-generating agent (see
# `SelfInstructGenerator.default_instruction_agent`).
DEFAULT_INSTRUCTION_SYSTEM_PROMPT = """
You are a high-capacity instruction generation assistant.
Your task is to generate a **new, creative, and challenging question** based on
several examples.
These examples may cover different domains or styles, but your goal is to:
- **Understand their specific patterns** in structure, and complexity;
- **Combine and synthesize** ideas from multiple examples, rather than copying
or lightly editing any single one;
- **Intelligently integrate** multiple reasoning steps, constraints, or
concepts into a single, coherent question;
- Ensure the new question is **non-trivial** and requires deep thinking or
multi-step reasoning.
**Guidelines:**
- Use the examples as inspiration for format, depth, and tone.
- Your new question should be self-contained, logically sound, and answerable.
- Do not repeat exact phrasings or create shallow combinations; instead,
produce something meaningfully new.
- Avoid open-ended or subjective questions that depend on personal opinions or
discussion.
- The generated question must have a **clear, objective, and verifiable
answer**.
- Aim for increased depth or novelty through subtle combination or
transformation.
- Keep the final output to a **single unified question** with one clear answer,
not a multi-part task.
**Output Format (strict):**
```
Question: [Generated question]
```
"""
# System message template for the rationale-generating agent; the
# `{package_list}` placeholder is filled via `.format(...)` in
# `SelfInstructGenerator.default_rationale_agent`.
DEFAULT_RATIONALE_SYSTEM_PROMPT = """You are an advanced Python code assistant.
Your task is to **solve the given question by writing Python code only**,
without any explanation or natural language output.
The code must compute the answer **programmatically**, not by hardcoding or
guessing the result.
**Rules:**
- Use Python code to perform the actual computation.
- Use {package_list} to solve the problem. Do not import any other libraries.
- **Do not hardcode the final answer** (e.g., avoid writing `print(1/2)` unless
that value is computed).
- The result must be obtained through valid computation logic in code.
- Do not include explanations. Output code only.
- The entire code must be wrapped in triple backticks:
```
[Your Python code here]
```
Now, solve the following question using Python. Only output the code:
"""
class SelfInstructGenerator(BaseGenerator):
r"""A generator for creating synthetic datapoints using self-instruct.
It utilizes both a human-provided dataset (seed_dataset) and generated
machine instructions (machine_instructions) to produce new, synthetic
datapoints that include a question, a computed rationale (code), and a
final answer (from a verifier).
"""
    class QuestionSchema(BaseModel):
        r"""Schema for the generated question.

        Attributes:
            question (str): The question generated by the model.
        """

        # Structured-output field parsed from the agent response in
        # `generate_new_instruction` (via `response_format=`).
        question: str = Field(description="The question generated")
    class RationaleSchema(BaseModel):
        r"""Schema for the generated rationale code.

        Attributes:
            code (str): The generated code without any formatting.
        """

        # Plain Python source (no Markdown fences) produced by the
        # rationale agent.
        code: str = Field(
            description="The generated code without any formatting"
        )
def __init__(
self,
seed_dataset: StaticDataset,
verifier: BaseVerifier,
instruction_agent: Optional[ChatAgent] = None,
rationale_agent: Optional[ChatAgent] = None,
seed: int = 42,
**kwargs,
):
r"""Initialize the self-instruct generator.
Args:
seed_dataset (StaticDataset): Dataset containing seed instructions.
verifier (BaseVerifier): Verifier instance to validate generated
solutions.
instruction_agent (Optional[ChatAgent]): Agent for generating
instructions. If not provided, a default agent will be created.
rationale_agent (Optional[ChatAgent]): Agent for generating
rationales. If not provided, a default agent will be created.
seed (int): Random seed for reproducibility. (default: :obj:`42`)
**kwargs: Additional keyword arguments passed to the BaseGenerator.
"""
super().__init__(seed=seed, **kwargs)
self.seed_dataset = seed_dataset
self.verifier = verifier
# extract packages from verifier
self.packages: List[str] = getattr(
self.verifier, "required_packages", []
)
# create default agents if not provided
self.instruction_agent = (
instruction_agent or self.default_instruction_agent()
)
self.rationale_agent = (
rationale_agent or self.default_rationale_agent()
)
# Extract questions from the seed dataset as human_instructions
self.human_instructions: List[str] = [
dp.question
for dp in list(cast(Iterable[DataPoint], self.seed_dataset))
]
self.machine_instructions: List[DataPoint] = []
# Create an instance-level lock for thread-safe updates to _data
self._lock = asyncio.Lock()
self._data = [] # Storage for generated DataPoint instances
def default_instruction_agent(self) -> ChatAgent:
r"""Create the default instruction generation agent.
This agent is configured with a moderate temperature setting to
encourage creative and diverse instruction generation behavior.
Returns:
ChatAgent: An agent with the default instruction prompt.
"""
model = ModelFactory.create(
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict={"temperature": 0.7},
)
return ChatAgent(
DEFAULT_INSTRUCTION_SYSTEM_PROMPT,
model=model,
)
def default_rationale_agent(self) -> ChatAgent:
r"""Create the default rationale generation agent.
This agent is configured with a deterministic (zero temperature)
setting to ensure consistent and precise rationale generation based on
a given instruction and package list.
Returns:
ChatAgent: An agent with the rationale prompt
"""
model = ModelFactory.create(
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict={"temperature": 0.0},
)
return ChatAgent(
DEFAULT_RATIONALE_SYSTEM_PROMPT.format(package_list=self.packages),
model=model,
)
@staticmethod
def format_support_block(dp: DataPoint) -> str:
r"""Format a DataPoint into a few-shot example block.
Args:
dp (DataPoint): A data point.
Returns:
str: A formatted string containing the question and its
corresponding code block in Markdown-style Python format.
"""
support_q = dp.question.strip()
support_code = dp.rationale.strip() if dp.rationale else ""
return (
f"Question:\n{support_q}\n\n"
"Code:\n"
"```python\n"
f"{support_code}\n"
"```"
)
def generate_new_instruction(
self,
agent: ChatAgent,
support_human_dps: list[DataPoint],
support_machine_dps: list[DataPoint],
) -> str:
r"""Generate a new instruction using self-instruct prompting.
Args:
agent (ChatAgent): The agent to use for generating the instruction.
support_human_dps (list[DataPoint]): List of human examples to
sample.
support_machine_dps (list[DataPoint]): List of machine examples to
sample.
Returns:
str: The newly generated question.
"""
human_sample = [dp.question for dp in list(support_human_dps)]
machine_sample = [dp.question for dp in list(support_machine_dps)]
few_shot_examples = human_sample + machine_sample
# Build the prompt using the few-shot examples
prompt = "Below are some question examples:\n\n"
for idx, instr in enumerate(few_shot_examples, start=1):
prompt += f"Question {idx}: {instr}\n"
prompt += f"Question {len(few_shot_examples) + 1}:\n"
prompt += "Now generate a new question based on the given examples.\n"
question_template = f"Question: {prompt}"
response = cast(
SelfInstructGenerator.QuestionSchema,
agent.step(question_template, response_format=self.QuestionSchema)
.msgs[0]
.parsed,
)
return response.question
def generate_rationale(
self,
question: str,
agent: Optional[ChatAgent] = None,
support_human_dps: Optional[list[DataPoint]] = None,
) -> str:
r"""Generate rationale code (solution) for the given question.
Args:
question (str): The question to be solved.
agent (Optional[ChatAgent]): The agent to use for generating the
rationale. If None is provided, the default rationale agent
will be used. (default: :obj:`None`)
support_human_dps (Optional[list[DataPoint]]): List of human
examples to sample. (default: :obj:`None`)
Returns:
str: The generated code solution as a string.
"""
# Build few-shot example prompt
few_shot_prompt = ""
if support_human_dps:
few_shot_examples = [
self.format_support_block(dp) for dp in support_human_dps
]
few_shot_prompt += "Below are example questions and solutions:\n\n"
few_shot_prompt += "\n\n".join(few_shot_examples)
few_shot_prompt += f"\n\nWrite code to solve the question:\n{question}"
response = cast(
SelfInstructGenerator.RationaleSchema,
(agent or self.default_rationale_agent())
.step(few_shot_prompt, response_format=self.RationaleSchema)
.msgs[0]
.parsed,
)
return response.code
    async def generate_new(
        self,
        n: int,
        max_retries: int = 10,
        human_sample_count: int = 3,
        machine_sample_count: int = 1,
        **kwargs,
    ) -> None:
        r"""Generates and validates `n` new datapoints through
        self-instruct prompting, with a retry limit.

        Args:
            n (int): The number of valid datapoints to generate.
            max_retries (int): Maximum number of retries before stopping.
                (default: :obj:`10`)
            human_sample_count (int): Number of human examples to sample.
                (default: :obj:`3`)
            machine_sample_count (int): Number of machine examples to sample.
                (default: :obj:`1`)
            **kwargs: Additional keyword arguments.

        Notes:
            - Retries on validation failures until `n` valid datapoints exist
              or `max_retries` is reached, whichever comes first.
            - If retries are exhausted before reaching `n`, a `RuntimeError`
              is raised.
            - Metadata includes a timestamp for tracking datapoint creation.
        """
        # Accumulates verified datapoints; committed to self._data only
        # after the loop so a failed run leaves the store untouched.
        valid_data_points: list[DataPoint] = []
        retries = 0
        # Every failure path below increments `retries` and continues; only
        # a fully verified and validated datapoint extends the list.
        while len(valid_data_points) < n and retries < max_retries:
            try:
                # Re-sample the few-shot pool each iteration: human examples
                # from the seed dataset, machine examples from earlier
                # generations (if any).
                human_dps_list = list(cast(List[DataPoint], self.seed_dataset))
                support_human_dps = random.sample(
                    human_dps_list,
                    min(human_sample_count, len(human_dps_list)),
                )
                machine_dps_list = list(self.machine_instructions)
                support_machine_dps = []
                if machine_dps_list and machine_sample_count > 0:
                    support_machine_dps = random.sample(
                        machine_dps_list,
                        min(machine_sample_count, len(machine_dps_list)),
                    )
                question = self.generate_new_instruction(
                    self.instruction_agent,
                    support_human_dps,
                    support_machine_dps,
                )
                rationale = self.generate_rationale(
                    question, self.rationale_agent, support_human_dps
                )
                if not isinstance(rationale, str):
                    raise TypeError(f"Rationale {rationale} is not a string.")
                # Execute/verify the generated code; an empty or falsy
                # result counts as a verification failure.
                try:
                    verifier_response = await self.verifier.verify(
                        solution=rationale,
                        reference_answer=None,
                    )
                    if not verifier_response or not verifier_response.result:
                        raise ValueError(
                            "Verifier unsuccessful, response: "
                            f"{verifier_response}"
                        )
                except (ValueError, AttributeError) as e:
                    logger.warning(
                        f"Verifier issue: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue
                # Schema validation of the assembled datapoint; failures
                # consume a retry rather than aborting the run.
                try:
                    new_datapoint = DataPoint(
                        question=question,
                        rationale=rationale,
                        final_answer=verifier_response.result,
                        metadata={
                            "synthetic": str(True),
                            "created": datetime.now().isoformat(),
                            "generator": "self_instruct",
                        },
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Datapoint validation failed: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue
                valid_data_points.append(new_datapoint)
            except Exception as e:
                # Catch-all so one bad generation never kills the loop;
                # it still consumes a retry.
                logger.warning(
                    f"Unexpected error: {e}, retrying..."
                    f" ({retries + 1}/{max_retries})"
                )
                retries += 1
        # Retry budget exhausted before reaching the target count.
        if len(valid_data_points) < n:
            raise RuntimeError(
                f"Failed to generate {n} valid datapoints "
                f"after {max_retries} retries."
            )
        # Commit under the instance lock for concurrent generate_new calls.
        async with self._lock:
            self._data.extend(valid_data_points)

View File

@@ -0,0 +1,400 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import random
from pathlib import Path
from typing import (
Any,
Dict,
List,
Optional,
Sized,
Union,
)
from datasets import Dataset as HFDataset
from pydantic import ValidationError
from torch.utils.data import Dataset
from camel.logger import get_logger
from .models import DataPoint
logger = get_logger(__name__)
class StaticDataset(Dataset):
    r"""A static dataset containing a list of datapoints.
    Ensures that all items adhere to the DataPoint schema.
    This dataset extends :obj:`Dataset` from PyTorch and should
    be used when its size is fixed at runtime.

    This class can initialize from Hugging Face Datasets,
    PyTorch Datasets, JSON file paths, or lists of dictionaries,
    converting them into a consistent internal format.
    """

    def __init__(
        self,
        data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
        seed: int = 42,
        min_samples: int = 1,
        strict: bool = False,
        **kwargs,
    ):
        r"""Initialize the static dataset and validate integrity.

        Args:
            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
                Input data, which can be one of the following:
                - A Hugging Face Dataset (:obj:`HFDataset`).
                - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
                - A :obj:`Path` object representing a JSON or JSONL file.
                - A list of dictionaries with :obj:`DataPoint`-compatible
                  fields.
            seed (int): Random seed for reproducibility.
                (default: :obj:`42`)
            min_samples (int): Minimum required number of samples.
                (default: :obj:`1`)
            strict (bool): Whether to raise an error on invalid
                datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
                (default: :obj:`False`)
            **kwargs: Additional dataset parameters.

        Raises:
            TypeError: If the input data type is unsupported.
            ValueError: If the dataset contains fewer than :obj:`min_samples`
                datapoints or if validation fails.
            FileNotFoundError: If the specified JSON file path does not exist.
            json.JSONDecodeError: If the JSON file contains invalid formatting.
        """
        # Store all parameters in metadata dict for compatibility
        self._metadata = {
            **kwargs,
        }
        # Dedicated RNG so sampling is reproducible and independent of the
        # global `random` module state.
        self._rng = random.Random(seed)
        self._strict = strict
        self.data: List[DataPoint] = self._init_data(data)
        self._length = len(self.data)
        if self._length < min_samples:
            raise ValueError(
                "The dataset does not contain enough samples. "
                f"Need {max(0, min_samples)}, got {self._length}"
            )

    def _init_data(
        self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
    ) -> List[DataPoint]:
        r"""Convert input data from various formats into a list of
        :obj:`DataPoint` instances.

        Args:
            data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]): Input
                dataset in one of the supported formats.

        Returns:
            List[DataPoint]: A list of validated :obj:`DataPoint`
                instances.

        Raises:
            TypeError: If the input data type is unsupported.
            ValueError: If the Path has an unsupported file extension.
        """
        if isinstance(data, HFDataset):
            raw_data = self._init_from_hf_dataset(data)
        elif isinstance(data, Dataset):
            raw_data = self._init_from_pytorch_dataset(data)
        elif isinstance(data, Path):
            if data.suffix == ".jsonl":
                raw_data = self._init_from_jsonl_path(data)
            elif data.suffix == ".json":
                raw_data = self._init_from_json_path(data)
            else:
                raise ValueError(
                    f"Unsupported file extension: {data.suffix}."
                    " Please enter a .json or .jsonl object."
                )
        elif isinstance(data, list):
            raw_data = self._init_from_list(data)
        else:
            raise TypeError("Unsupported data type")

        def create_datapoint(
            item: Dict[str, Any], idx: int
        ) -> Optional[DataPoint]:
            # Pre-check the required string fields explicitly so that
            # non-strict mode can skip bad rows with a clear log message
            # (and to make mypy happy about the types passed to DataPoint).
            question = item.get('question')
            if not isinstance(question, str):
                if self._strict:
                    raise ValueError(
                        f"Sample at index {idx} has invalid 'question': "
                        f"expected str, got {type(question)}"
                    )
                else:
                    logger.warning(
                        f"Skipping sample at index {idx}: invalid 'question'"
                    )
                    return None
            rationale = item.get('rationale')
            final_answer = item.get('final_answer')
            if not isinstance(final_answer, str):
                if self._strict:
                    raise ValueError(
                        f"Sample at index {idx} has invalid 'final_answer': "
                        f"expected str, got {type(final_answer)}"
                    )
                else:
                    logger.warning(
                        f"Skipping sample at index {idx}: "
                        "invalid 'final_answer'"
                    )
                    return None
            # Any remaining schema problems (e.g. a bad `rationale` or
            # `metadata`) surface here via pydantic validation.
            try:
                return DataPoint(
                    question=question,
                    rationale=rationale,
                    final_answer=final_answer,
                    metadata=item.get('metadata'),
                )
            except ValidationError as e:
                if self._strict:
                    raise ValueError(
                        f"Sample at index {idx} validation error: {e}"
                    ) from e
                else:
                    logger.warning(
                        f"Skipping invalid sample at index {idx} "
                        f"due to validation error: {e}"
                    )
                    return None

        unfiltered_data = [
            create_datapoint(item, i) for i, item in enumerate(raw_data)
        ]
        # In non-strict mode invalid rows come back as None; drop them.
        return [dp for dp in unfiltered_data if dp is not None]

    def __len__(self) -> int:
        r"""Return the size of the dataset."""
        return self._length

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[DataPoint, List[DataPoint]]:
        r"""Retrieve a datapoint or a batch of datapoints by index or slice.

        Args:
            idx (Union[int, slice]): Index or slice of the datapoint(s).

        Returns:
            Union[DataPoint, List[DataPoint]]: A single :obj:`DataPoint`
                for an integer index, or a list for a slice.

        Raises:
            IndexError: If an integer `idx` is out of bounds.
            TypeError: If `idx` is neither an int nor a slice.
        """
        if isinstance(idx, int):
            if idx < 0 or idx >= self._length:
                raise IndexError(
                    f"Index {idx} out of bounds for dataset "
                    f"of size {self._length}"
                )
            return self.data[idx]
        elif isinstance(idx, slice):
            return self.data[idx]
        else:
            raise TypeError(f"Indexing type {type(idx)} not supported.")

    def sample(self) -> DataPoint:
        r"""Sample a random datapoint from the dataset.

        Returns:
            DataPoint: A randomly sampled :obj:`DataPoint`.

        Raises:
            RuntimeError: If the dataset is empty and no samples can be drawn.
        """
        if self._length == 0:
            raise RuntimeError("Dataset is empty, cannot sample.")
        idx = self._rng.randint(0, self._length - 1)
        sample = self[idx]
        # Integer indexing always yields a single DataPoint; this guard
        # exists purely to narrow the Union return type of __getitem__.
        if not isinstance(sample, DataPoint):
            raise TypeError(
                f"Expected DataPoint instance, got {type(sample).__name__}"
            )
        return sample

    @property
    def metadata(self) -> Dict[str, Any]:
        r"""Retrieve dataset metadata.

        Returns:
            Dict[str, Any]: A copy of the dataset metadata dictionary.
        """
        return self._metadata.copy()

    def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
        r"""Convert a Hugging Face dataset into a list of dictionaries.

        Args:
            data (HFDataset): A Hugging Face dataset.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries representing
                the dataset, where each dictionary corresponds to a datapoint.
        """
        return [dict(item) for item in data]

    def _init_from_pytorch_dataset(
        self, data: Dataset
    ) -> List[Dict[str, Any]]:
        r"""Convert a PyTorch dataset into a list of dictionaries.

        Args:
            data (Dataset): A PyTorch dataset.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries representing
                the dataset.

        Raises:
            TypeError: If the dataset does not implement :obj:`__len__()`
                or contains non-dictionary elements.
        """
        # Map-style iteration requires a known length.
        if not isinstance(data, Sized):
            raise TypeError(
                f"{type(data).__name__} does not implement `__len__()`."
            )
        raw_data = []
        for i in range(len(data)):
            item = data[i]
            if not isinstance(item, dict):
                raise TypeError(
                    f"Item at index {i} is not a dict: "
                    f"got {type(item).__name__}"
                )
            raw_data.append(dict(item))
        return raw_data

    def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
        r"""Load and parse a dataset from a JSON file.

        Args:
            data (Path): Path to the JSON file.

        Returns:
            List[Dict[str, Any]]: A list of datapoint dictionaries.

        Raises:
            FileNotFoundError: If the specified JSON file does not exist.
            ValueError: If the JSON content is not a list of dictionaries
                or has invalid formatting.
        """
        if not data.exists():
            raise FileNotFoundError(f"JSON file not found: {data}")
        try:
            logger.debug(f"Loading JSON from {data}")
            with data.open('r', encoding='utf-8') as f:
                loaded_data = json.load(f)
            logger.info(
                f"Successfully loaded {len(loaded_data)} items from {data}"
            )
        except json.JSONDecodeError as e:
            # Chain the original decode error so line/column info survives.
            raise ValueError(f"Invalid JSON in file {data}: {e}") from e
        if not isinstance(loaded_data, list):
            raise ValueError("JSON file must contain a list of dictionaries")
        for i, item in enumerate(loaded_data):
            if not isinstance(item, dict):
                raise ValueError(
                    f"Expected a dictionary at index {i}, "
                    f"got {type(item).__name__}"
                )
        return loaded_data

    def _init_from_jsonl_path(self, data: Path) -> List[Dict[str, Any]]:
        r"""Load and parse a dataset from a JSONL file.

        Args:
            data (Path): Path to the JSONL file.

        Returns:
            List[Dict[str, Any]]: A list of datapoint dictionaries.

        Raises:
            FileNotFoundError: If the specified JSONL file does not exist.
            ValueError: If a line in the file contains invalid JSON or
                is not a dictionary.
        """
        if not data.exists():
            raise FileNotFoundError(f"JSONL file not found: {data}")
        raw_data = []
        logger.debug(f"Loading JSONL from {data}")
        with data.open('r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # Skip blank lines if any exist.
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"Invalid JSON on line {line_number} in file "
                        f"{data}: {e}"
                    ) from e
                # Validate here so the error reports the true file line;
                # a post-hoc loop over `raw_data` would mis-count whenever
                # blank lines were skipped.
                if not isinstance(record, dict):
                    raise ValueError(
                        f"Expected a dictionary at line {line_number} in "
                        f"file {data}, got {type(record).__name__}"
                    )
                raw_data.append(record)
        logger.info(f"Successfully loaded {len(raw_data)} items from {data}")
        return raw_data

    def _init_from_list(
        self, data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Validate and convert a list of dictionaries into a dataset.

        Args:
            data (List[Dict[str, Any]]): A list of dictionaries where
                each dictionary must be a valid :obj:`DataPoint`.

        Returns:
            List[Dict[str, Any]]: The validated list of dictionaries.

        Raises:
            ValueError: If any item in the list is not a dictionary.
        """
        for i, item in enumerate(data):
            if not isinstance(item, dict):
                raise ValueError(
                    f"Expected a dictionary at index {i}, "
                    f"got {type(item).__name__}"
                )
        return data

View File

@@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .azure_embedding import AzureEmbedding
from .base import BaseEmbedding
from .jina_embedding import JinaEmbedding
from .mistral_embedding import MistralEmbedding
from .openai_compatible_embedding import OpenAICompatibleEmbedding
from .openai_embedding import OpenAIEmbedding
from .sentence_transformers_embeddings import SentenceTransformerEncoder
from .together_embedding import TogetherEmbedding
from .vlm_embedding import VisionLanguageEmbedding
# Public API of the camel.embeddings package; keep this list in sync with
# the imports above.
__all__ = [
    "BaseEmbedding",
    "OpenAIEmbedding",
    "AzureEmbedding",
    "SentenceTransformerEncoder",
    "VisionLanguageEmbedding",
    "MistralEmbedding",
    "OpenAICompatibleEmbedding",
    "JinaEmbedding",
    "TogetherEmbedding",
]

View File

@@ -0,0 +1,119 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations
import os
from typing import Any, Union
from openai import AzureOpenAI
from camel.embeddings.base import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import api_keys_required
class AzureEmbedding(BaseEmbedding[str]):
    r"""Text embeddings backed by the Azure OpenAI service.

    Args:
        model_type (EmbeddingModelType, optional): The model type to be
            used for text embeddings.
            (default: :obj:`TEXT_EMBEDDING_3_SMALL`)
        url (Optional[str], optional): The url to the Azure OpenAI service.
            (default: :obj:`None`)
        api_key (str, optional): The API key for authenticating with the
            Azure OpenAI service. (default: :obj:`None`)
        api_version (str, optional): The API version for Azure OpenAI service.
            (default: :obj:`None`)
        dimensions (Optional[int], optional): The text embedding output
            dimensions. (default: :obj:`None`)

    Raises:
        RuntimeError: If an unsupported model type is specified.
        ValueError: If required API configuration is missing.
    """

    @api_keys_required(
        [
            ("api_key", 'AZURE_OPENAI_API_KEY'),
            ("url", 'AZURE_OPENAI_BASE_URL'),
        ]
    )
    def __init__(
        self,
        model_type: EmbeddingModelType = (
            EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
        ),
        url: Union[str, None] = None,
        api_key: Union[str, None] = None,
        api_version: Union[str, None] = None,
        dimensions: Union[int, None] = None,
    ) -> None:
        self.model_type = model_type
        self.api_version = api_version or os.environ.get("AZURE_API_VERSION")
        # Honour an explicit dimension override; otherwise fall back to the
        # model's native output width.
        if dimensions is not None:
            if not isinstance(dimensions, int):
                raise ValueError("dimensions must be an integer")
            self.output_dim = dimensions
        else:
            self.output_dim = model_type.output_dim
        self._api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
        self._url = url or os.environ.get("AZURE_OPENAI_BASE_URL")
        self.client = AzureOpenAI(
            api_key=self._api_key,
            api_version=self.api_version,
            azure_endpoint=str(self._url),
        )

    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Embed a batch of texts with the configured Azure OpenAI model.

        Args:
            objs (list[str]): The list of texts to embed.
            **kwargs (Any): Additional keyword arguments to pass to the API.

        Returns:
            list[list[float]]: The embeddings for the input texts.
        """
        request = {"input": objs, "model": self.model_type.value}
        # The legacy ada-002 model does not take a `dimensions` argument;
        # every other supported model receives the configured output size.
        if self.model_type == EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
            response = self.client.embeddings.create(**request, **kwargs)
        else:
            response = self.client.embeddings.create(
                **request, dimensions=self.output_dim, **kwargs
            )
        return [item.embedding for item in response.data]

    def get_output_dim(self) -> int:
        r"""Return the dimensionality of embeddings for the current model.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim

Some files were not shown because too many files have changed in this diff Show More