# Modified from:
# https://github.com/nyno-ai/openai-token-counter

from typing import Any, Callable, Dict, List, Optional

from llama_index.legacy.llms import ChatMessage, MessageRole
from llama_index.legacy.utils import get_tokenizer

class TokenCounter:
    """Count tokens in strings and estimate token usage of chat payloads.

    Attributes:
        tokenizer (Callable[[str], list]): Callable that splits a string into
            tokens. Defaults to the tokenizer returned by ``get_tokenizer()``.
    """

    def __init__(self, tokenizer: Optional[Callable[[str], list]] = None) -> None:
        """Initialize the counter.

        Args:
            tokenizer (Optional[Callable[[str], list]]): Tokenizer callable;
                falls back to ``get_tokenizer()`` when not provided.
        """
        self.tokenizer = tokenizer or get_tokenizer()

    def get_string_tokens(self, string: str) -> int:
        """Get the token count for a string.

        Args:
            string (str): The string to count.

        Returns:
            int: The token count.
        """
        return len(self.tokenizer(string))

    def estimate_tokens_in_messages(self, messages: List[ChatMessage]) -> int:
        """Estimate the total token count for a list of chat messages.

        Uses the OpenAI-style per-message accounting: tokens for role and
        content, plus fixed overheads for message framing and function calls.

        Args:
            messages (List[ChatMessage]): The messages to estimate the token
                count for.

        Returns:
            int: The estimated token count.
        """
        tokens = 0

        for message in messages:
            if message.role:
                tokens += self.get_string_tokens(message.role)

            if message.content:
                tokens += self.get_string_tokens(message.content)

            # Copy so popping "function_call" does not mutate the message.
            additional_kwargs = {**message.additional_kwargs}

            if "function_call" in additional_kwargs:
                function_call = additional_kwargs.pop("function_call")
                if function_call.get("name", None) is not None:
                    tokens += self.get_string_tokens(function_call["name"])

                if function_call.get("arguments", None) is not None:
                    tokens += self.get_string_tokens(function_call["arguments"])

                tokens += 3  # Additional tokens for function call

            tokens += 3  # Add three per message

            if message.role == MessageRole.FUNCTION:
                tokens -= 2  # Subtract 2 if role is "function"

        return tokens

    def estimate_tokens_in_functions(self, functions: List[Dict[str, Any]]) -> int:
        """Estimate the token count for a list of function definitions.

        We take here a list of functions created using the `to_openai_spec`
        function (or similar).

        Args:
            functions (List[Dict[str, Any]]): The functions to estimate the
                token count for.

        Returns:
            int: The estimated token count.
        """
        prompt_definition = str(functions)
        tokens = self.get_string_tokens(prompt_definition)
        tokens += 9  # Additional tokens for function definition
        return tokens