fix token count error in use case

This commit is contained in:
Wendong
2025-05-11 18:59:10 +08:00
parent 590ba05a43
commit 18a7f89097

View File

@@ -13,11 +13,9 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
import sys
import contextlib
import time
from pathlib import Path
from typing import List, Dict, Any
import os
from colorama import Fore, init
from dotenv import load_dotenv
@@ -26,7 +24,6 @@ from camel.agents.chat_agent import ToolCallingRecord
from camel.models import ModelFactory
from camel.toolkits import FunctionTool, MCPToolkit
from camel.types import ModelPlatformType, ModelType
from camel.logger import set_log_level
from owl.utils.enhanced_role_playing import OwlRolePlaying
@@ -39,8 +36,6 @@ base_dir = pathlib.Path(__file__).parent.parent
env_path = base_dir / "owl" / ".env"
load_dotenv(dotenv_path=str(env_path))
set_log_level(level="INFO")
async def construct_society(
question: str,
@@ -146,9 +141,6 @@ def write_to_md(filename: str, content: Dict[str, Any]) -> None:
if "summary" in content:
f.write(f"## Summary\n\n")
f.write(f"{content['summary']}\n\n")
if "token_count" in content:
f.write(f"**Total tokens used:** {content['token_count']}\n\n")
async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename: str, round_limit: int = 15):
@@ -182,12 +174,25 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
input_msg = society.init_chat()
chat_history = []
token_count = {"total": 0}
overall_completion_token_count = 0
overall_prompt_token_count = 0
n = 0
while n < round_limit:
n += 1
assistant_response, user_response = await society.astep(input_msg)
overall_completion_token_count += assistant_response.info["usage"].get(
"completion_tokens", 0
) + user_response.info["usage"].get("completion_tokens", 0)
overall_prompt_token_count += assistant_response.info["usage"].get(
"prompt_tokens", 0
) + user_response.info["usage"].get("prompt_tokens", 0)
token_info = {
"completion_token_count": overall_completion_token_count,
"prompt_token_count": overall_prompt_token_count,
}
md_content = {}
@@ -239,10 +244,6 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
"user": user_response.msg.content,
})
# Update token count
if "token_count" in assistant_response.info:
token_count["total"] += assistant_response.info["token_count"]
if "TASK_DONE" in user_response.msg.content:
task_done_msg = "Task completed successfully!"
print(Fore.YELLOW + task_done_msg + "\n")
@@ -252,12 +253,12 @@ async def run_society_with_formatted_output(society: OwlRolePlaying, md_filename
input_msg = assistant_response.msg
# Write token count information
write_to_md(md_filename, {"token_count": token_count["total"]})
write_to_md(md_filename, token_info)
# Extract final answer
answer = assistant_response.msg.content if assistant_response and assistant_response.msg else ""
return answer, chat_history, token_count
return answer, chat_history, token_info
async def main():
@@ -278,7 +279,7 @@ async def main():
# Use command line argument if provided, otherwise use default task
task = sys.argv[1] if len(sys.argv) > 1 else default_task
mcp_toolkit = MCPToolkit(config_path=str(config_path))
mcp_toolkit = MCPToolkit(config_path=str(config_path), strict=True)
try:
# Create markdown file for conversation export
@@ -294,10 +295,10 @@ async def main():
# Build and run society
print(Fore.YELLOW + f"Starting task: {task}\n")
society = await construct_society(task, tools)
answer, chat_history, token_count = await run_society_with_formatted_output(society, md_filename)
answer, chat_history, token_info = await run_society_with_formatted_output(society, md_filename)
print(Fore.GREEN + f"\nFinal Result: {answer}")
print(Fore.CYAN + f"Total tokens used: {token_count['total']}")
print(Fore.CYAN + f"Total tokens used: {token_info}")
print(Fore.CYAN + f"Full conversation log saved to: {md_filename}")
except KeyboardInterrupt: