llama-index
499 строк · 14.0 Кб
1"""General utils functions."""
2
3import asyncio
4import os
5import random
6import sys
7import time
8import traceback
9import uuid
10from contextlib import contextmanager
11from dataclasses import dataclass
12from functools import partial, wraps
13from itertools import islice
14from pathlib import Path
15from typing import (
16Any,
17AsyncGenerator,
18Callable,
19Dict,
20Generator,
21Iterable,
22List,
23Optional,
24Protocol,
25Set,
26Type,
27Union,
28runtime_checkable,
29)
30
31
class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load.
    (e.g. tokenization)

    """

    _stopwords: Optional[List[str]] = None
    _nltk_data_dir: Optional[str] = None

    def __init__(self) -> None:
        """Initialize NLTK stopwords and punkt."""
        import nltk

        default_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "_static/nltk_cache",
        )
        self._nltk_data_dir = os.environ.get("NLTK_DATA", default_dir)

        if self._nltk_data_dir not in nltk.data.path:
            nltk.data.path.append(self._nltk_data_dir)

        # Ensure both required resources are present, downloading on first use.
        for resource_path, resource_name in (
            ("corpora/stopwords", "stopwords"),
            ("tokenizers/punkt", "punkt"),
        ):
            try:
                nltk.data.find(resource_path, paths=[self._nltk_data_dir])
            except LookupError:
                nltk.download(resource_name, download_dir=self._nltk_data_dir)

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords (English), loading and caching them on first access."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            try:
                nltk.data.find("corpora/stopwords", paths=[self._nltk_data_dir])
            except LookupError:
                nltk.download("stopwords", download_dir=self._nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords
87
88
# Module-level singleton; constructing it triggers the NLTK data checks above.
globals_helper = GlobalsHelper()
90
91
92# Global Tokenizer
@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for a tokenizer: anything exposing an ``encode`` method."""

    def encode(self, text: str, *args: Any, **kwargs: Any) -> List[Any]:
        """Convert ``text`` into a list of tokens."""
        ...
97
98
def set_global_tokenizer(tokenizer: Union[Tokenizer, Callable[[str], list]]) -> None:
    """Install ``tokenizer`` as the process-wide default tokenizer.

    Accepts either a ``Tokenizer`` object (its bound ``encode`` method is
    stored) or a bare callable mapping text to a token list.
    """
    import llama_index.legacy

    # Normalize to a plain callable before storing it on the package module.
    encode_fn = tokenizer.encode if isinstance(tokenizer, Tokenizer) else tokenizer
    llama_index.legacy.global_tokenizer = encode_fn
106
107
def get_tokenizer() -> Callable[[str], List]:
    """Return the global tokenizer, initializing it from tiktoken if unset.

    On first use (no global tokenizer configured) this loads the
    ``gpt-3.5-turbo`` encoding from ``tiktoken`` and installs it via
    ``set_global_tokenizer``.

    Returns:
        Callable[[str], List]: Function mapping text to a list of token ids.

    Raises:
        ImportError: If no global tokenizer is set and ``tiktoken`` is not
            installed.
    """
    import llama_index.legacy

    if llama_index.legacy.global_tokenizer is None:
        tiktoken_import_err = (
            "`tiktoken` package not found, please run `pip install tiktoken`"
        )
        try:
            import tiktoken
        except ImportError:
            raise ImportError(tiktoken_import_err)

        # Point tiktoken at the bundled cache dir unless the user configured
        # one. Fix: restore the environment in a ``finally`` so a failure in
        # ``encoding_for_model`` cannot leak the temporary env var.
        should_revert = False
        try:
            if "TIKTOKEN_CACHE_DIR" not in os.environ:
                should_revert = True
                os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "_static/tiktoken_cache",
                )

            enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
            tokenizer = partial(enc.encode, allowed_special="all")
            set_global_tokenizer(tokenizer)
        finally:
            if should_revert:
                del os.environ["TIKTOKEN_CACHE_DIR"]

    assert llama_index.legacy.global_tokenizer is not None
    return llama_index.legacy.global_tokenizer
138
139
def get_new_id(d: Set) -> str:
    """Return a fresh UUID4 string that does not collide with any ID in ``d``."""
    candidate = str(uuid.uuid4())
    while candidate in d:
        candidate = str(uuid.uuid4())
    return candidate
147
148
def get_new_int_id(d: Set) -> int:
    """Return a random non-negative integer ID not already present in ``d``."""
    candidate = random.randint(0, sys.maxsize)
    while candidate in d:
        candidate = random.randint(0, sys.maxsize)
    return candidate
156
157
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility class for setting a temporary value for an attribute on a class.
    Taken from: https://tinyurl.com/2p89xymh

    """
    # Snapshot current values, apply overrides, and restore on exit even if
    # the managed body raises.
    saved = {name: getattr(obj, name) for name in kwargs}
    for name, value in kwargs.items():
        setattr(obj, name, value)
    try:
        yield
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)
174
175
@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None
190
191
def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    # Map exact exception class -> optional per-class retry predicate.
    check_by_class = {
        entry.exception_cls: entry.check_fn for entry in errors_to_retry
    }
    retryable_classes = tuple(check_by_class)

    attempt = 0
    backoff = min_backoff_secs
    while True:
        try:
            return lambda_fn()
        except retryable_classes as exc:
            traceback.print_exc()
            attempt += 1
            if attempt >= max_tries:
                raise
            # NOTE: predicate lookup is by exact class, so a subclass of a
            # registered exception is retried unconditionally.
            predicate = check_by_class.get(type(exc))
            if predicate and not predicate(exc):
                raise
            time.sleep(backoff)
            backoff = min(backoff * 2, max_backoff_secs)
237
238
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length, marking truncation with ``...``.

    Args:
        text (str): Text to truncate.
        max_length (int): Maximum length of the returned string.

    Returns:
        str: ``text`` unchanged if it fits, otherwise a string of exactly
        ``max_length`` characters ending in ``...`` (or a bare prefix when
        there is no room for the ellipsis).
    """
    if len(text) <= max_length:
        return text
    # Fix: for max_length < 3 the original negative slice produced output
    # LONGER than max_length (e.g. ("abcdef", 2) -> "abc..."). Fall back to a
    # plain prefix when the ellipsis cannot fit.
    if max_length < 3:
        return text[:max_length]
    return text[: max_length - 3] + "..."
244
245
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    it = iter(iterable)
    # islice yields up to ``size`` items per pass; an empty batch means the
    # source is exhausted.
    while batch := list(islice(it, size)):
        yield batch
258
259
def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) will add a backslash before dirname if
    basename does not end with a slash, so we make sure it does.

    Fix: the original indexed ``dirname[-1]`` and raised IndexError for an
    empty ``dirname``; ``str.endswith`` handles that case safely.
    """
    if dirname and not dirname.endswith("/"):
        dirname += "/"
    return os.path.join(dirname, basename)
269
270
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            # tqdm not installed: silently fall back to the raw iterable.
            pass
    return items
284
285
def count_tokens(text: str) -> int:
    """Return the number of tokens ``text`` encodes to under the global tokenizer."""
    return len(get_tokenizer()(text))
290
291
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """
    Args:
        model_name(str): the model name of the tokenizer.
            For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        # NOTE(review): raises ValueError (not ImportError) for a missing
        # package — inconsistent with the rest of this module, but kept since
        # callers may already catch ValueError.
        raise ValueError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    return AutoTokenizer.from_pretrained(model_name).tokenize
306
307
def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    override = os.environ.get("LLAMA_INDEX_CACHE_DIR")
    if override is not None:
        # User override wins over all platform defaults.
        path = Path(override)
    elif os.name == "posix" and sys.platform != "darwin":
        # Linux, Unix, AIX, etc.
        path = Path("/tmp/llama_index")
    elif sys.platform == "darwin":
        # Mac OS
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
    else:
        # Windows (hopefully)
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        # exist_ok guards against a concurrent creator; prevents
        # https://github.com/jerryjliu/llama_index/issues/7362
        os.makedirs(path, exist_ok=True)
    return str(path)
336
337
def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _sync_variant(*args: Any, **kwds: Any) -> Any:
        # Drive the coroutine to completion on the current event loop.
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(func(*args, **kwds))

    func.sync = _sync_variant
    return func
353
354
# Sample text from llama_index.legacy's readme
# (used as a fixture/demo document; content is part of the constant's value).
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""
381
# ANSI 24-bit foreground codes ("38;2;R;G;B") for the LlamaIndex brand palette.
_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

# Basic ANSI foreground codes, plus one 256-color ("38;5;N") pink.
_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}
398
399
def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    palette = _LLAMA_INDEX_COLORS if use_llama_index_colors else _ANSI_COLORS
    # Cycle through the palette's color names in declaration order.
    names = list(palette)
    return {item: names[index % len(names)] for index, item in enumerate(items)}
422
423
def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    palette = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
    code = palette.get(color)

    # Unknown color name: fall back to bold + italic with no color code.
    if code is None:
        return f"\033[1;3m{text}\033[0m"

    return f"\033[1;3;{code}m{text}\033[0m"
443
444
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.

    Returns:
        None
    """
    if color is None:
        print(text, end=end)
    else:
        print(_get_colored_text(text, color), end=end)
461
462
def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        cuda_available = torch.cuda.is_available()
    except NameError:
        # torch is not imported at module level; the local import below makes
        # ``torch`` a function local, so the first reference above raises
        # UnboundLocalError (a NameError subclass) and lands here.
        import torch

        cuda_available = torch.cuda.is_available()

    if cuda_available:
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
476
477
def unit_generator(x: Any) -> Generator[Any, None, None]:
    """Yield exactly one element, ``x``.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield from (x,)
488
489
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """Asynchronously yield exactly one element, ``x``.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    for item in (x,):
        yield item
500