llama-index

Форк
0
499 строк · 14.0 Кб
1
"""General utils functions."""
2

3
import asyncio
4
import os
5
import random
6
import sys
7
import time
8
import traceback
9
import uuid
10
from contextlib import contextmanager
11
from dataclasses import dataclass
12
from functools import partial, wraps
13
from itertools import islice
14
from pathlib import Path
15
from typing import (
16
    Any,
17
    AsyncGenerator,
18
    Callable,
19
    Dict,
20
    Generator,
21
    Iterable,
22
    List,
23
    Optional,
24
    Protocol,
25
    Set,
26
    Type,
27
    Union,
28
    runtime_checkable,
29
)
30

31

32
class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to
    load (e.g. tokenization data).

    Constructing an instance imports ``nltk``, registers an NLTK data directory
    (``$NLTK_DATA`` or the package-local ``_static/nltk_cache``) on
    ``nltk.data.path``, and downloads the ``stopwords`` and ``punkt`` resources
    into that directory when they are not already found there.
    """

    _stopwords: Optional[List[str]] = None
    _nltk_data_dir: Optional[str] = None

    # Shared by every code path that needs nltk, so users always get the same
    # actionable message instead of a bare ModuleNotFoundError.
    _NLTK_MISSING_MSG = "`nltk` package not found, please run `pip install nltk`"

    def __init__(self) -> None:
        """Initialize NLTK stopwords and punkt.

        Raises:
            ImportError: if the ``nltk`` package is not installed.
        """
        nltk = self._import_nltk()

        self._nltk_data_dir = os.environ.get(
            "NLTK_DATA",
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "_static/nltk_cache",
            ),
        )

        if self._nltk_data_dir not in nltk.data.path:
            nltk.data.path.append(self._nltk_data_dir)

        # ensure access to data is there
        self._ensure_resource(nltk, "corpora/stopwords", "stopwords")
        self._ensure_resource(nltk, "tokenizers/punkt", "punkt")

    @classmethod
    def _import_nltk(cls) -> Any:
        """Import and return ``nltk``, raising an actionable ImportError if missing."""
        try:
            import nltk
        except ImportError as err:
            raise ImportError(cls._NLTK_MISSING_MSG) from err
        return nltk

    def _ensure_resource(self, nltk: Any, resource_path: str, package: str) -> None:
        """Download ``package`` unless ``resource_path`` exists in the data dir.

        Only ``self._nltk_data_dir`` is searched, so the download (if any) lands
        in that same directory.
        """
        try:
            nltk.data.find(resource_path, paths=[self._nltk_data_dir])
        except LookupError:
            nltk.download(package, download_dir=self._nltk_data_dir)

    @property
    def stopwords(self) -> List[str]:
        """Get the English NLTK stopword list, loading and caching it on first use.

        Raises:
            ImportError: if the ``nltk`` package is not installed.
        """
        if self._stopwords is None:
            nltk = self._import_nltk()
            try:
                from nltk.corpus import stopwords
            except ImportError as err:
                raise ImportError(self._NLTK_MISSING_MSG) from err

            self._ensure_resource(nltk, "corpora/stopwords", "stopwords")
            self._stopwords = stopwords.words("english")
        return self._stopwords
87

88

89
# Shared module-level instance. Creating it runs GlobalsHelper.__init__ at
# import time, which imports nltk and may download the stopwords/punkt data.
globals_helper = GlobalsHelper()
90

91

92
# Global Tokenizer
@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for tokenizer objects.

    Any object exposing ``encode(text, *args, **kwargs) -> list`` satisfies it.
    Because the protocol is ``runtime_checkable``, ``isinstance`` can be used,
    but that check only verifies that an ``encode`` attribute exists.
    """

    def encode(self, text: str, *args: Any, **kwargs: Any) -> List[Any]:
        """Encode ``text`` into a list of tokens."""
97

98

99
def set_global_tokenizer(tokenizer: Union[Tokenizer, Callable[[str], list]]) -> None:
    """Register ``tokenizer`` as the package-wide tokenizer.

    Objects matching the ``Tokenizer`` protocol are stored as their bound
    ``encode`` method; any other callable is stored unchanged. The result is
    written to ``llama_index.legacy.global_tokenizer``.
    """
    import llama_index.legacy

    tokenize_fn = tokenizer.encode if isinstance(tokenizer, Tokenizer) else tokenizer
    llama_index.legacy.global_tokenizer = tokenize_fn
106

107

108
def get_tokenizer() -> Callable[[str], List]:
    """Return the global tokenizer, creating a tiktoken default on first use.

    If ``llama_index.legacy.global_tokenizer`` is unset, a tokenizer is built
    from tiktoken's ``gpt-3.5-turbo`` encoding (with all special tokens
    allowed) and registered through ``set_global_tokenizer``.

    Returns:
        Callable[[str], List]: the registered tokenizer function.

    Raises:
        ImportError: if a default tokenizer is needed and ``tiktoken`` is not
            installed.
    """
    import llama_index.legacy

    if llama_index.legacy.global_tokenizer is None:
        tiktoken_import_err = (
            "`tiktoken` package not found, please run `pip install tiktoken`"
        )
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(tiktoken_import_err) from err

        # Point tiktoken at the cache bundled with this package unless the user
        # already chose a cache dir. The temporary override is removed in
        # ``finally`` so it does not leak into the process when loading fails.
        should_revert = "TIKTOKEN_CACHE_DIR" not in os.environ
        if should_revert:
            os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "_static/tiktoken_cache",
            )
        try:
            enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
            tokenizer = partial(enc.encode, allowed_special="all")
            set_global_tokenizer(tokenizer)
        finally:
            if should_revert:
                os.environ.pop("TIKTOKEN_CACHE_DIR", None)

    # Type narrowing: either a tokenizer was already set, or one was just
    # registered above.
    assert llama_index.legacy.global_tokenizer is not None
    return llama_index.legacy.global_tokenizer
138

139

140
def get_new_id(d: Set) -> str:
    """Return a random UUID4 string that is not already a member of ``d``."""
    candidate = str(uuid.uuid4())
    while candidate in d:
        candidate = str(uuid.uuid4())
    return candidate
147

148

149
def get_new_int_id(d: Set) -> int:
    """Return a random integer in ``[0, sys.maxsize]`` not already in ``d``."""
    candidate = random.randint(0, sys.maxsize)
    while candidate in d:
        candidate = random.randint(0, sys.maxsize)
    return candidate
156

157

158
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Context manager that sets temporary values for attributes of ``obj``.
    Current values of every attribute named in ``kwargs`` are recorded and the
    new values are assigned. On exit the recorded values are restored,
    including when the ``with`` body raises or when one of the temporary
    assignments itself fails part-way through.
    Taken from: https://tinyurl.com/2p89xymh

    Args:
        obj (Any): object whose attributes are temporarily overridden.
        **kwargs (Any): attribute names mapped to their temporary values.

    Raises:
        AttributeError: if ``obj`` lacks one of the named attributes; this is
            raised while recording, before anything is modified.
    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    try:
        # Assign inside ``try`` so a failing setattr still triggers the
        # restore of attributes that were already overridden.
        for k, v in kwargs.items():
            setattr(obj, k, v)
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)
174

175

176
@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry. ``None`` (the default) means the exception is
            always retried.
    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None
190

191

192
def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Every failure that matches one of ``errors_to_retry`` prints its traceback
    and, unless the attempt budget is exhausted or the matching ``check_fn``
    rejects it, is retried after sleeping. The sleep starts at
    ``min_backoff_secs`` and doubles after each retry, capped at
    ``max_backoff_secs``.

    An exception matches an entry when it is an instance of the entry's
    ``exception_cls``, subclasses included. The entry registered for the
    exception's exact class takes precedence. Otherwise the first registered
    base class it is an instance of is used.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    Returns:
        Any: the return value of the first successful ``lambda_fn()`` call.

    Raises:
        ValueError: if ``errors_to_retry`` is empty.
        Exception: the caught exception is re-raised once ``max_tries`` is
            reached or when its ``check_fn`` returns False. Exceptions not
            matching any entry propagate immediately.
    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    # If the same class is listed twice, the later entry's check_fn wins.
    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            if e.__class__ in error_checks:
                check_fn = error_checks[e.__class__]
            else:
                # ``e`` was caught through a registered base class; a plain
                # ``dict.get`` on its concrete class would silently skip the
                # check_fn. ``next`` always succeeds because ``except`` matched.
                check_fn = next(
                    fn for cls, fn in error_checks.items() if isinstance(e, cls)
                )
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)
237

238

239
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length.

    Text that fits is returned unchanged. Longer text is cut and suffixed with
    ``"..."`` so the result is exactly ``max_length`` characters. When
    ``max_length`` is too small to hold the ellipsis (below 3), the text is
    simply cut to ``max_length`` characters. For ``max_length <= 0`` the
    result is the empty string.
    """
    if len(text) <= max_length:
        return text
    if max_length < 3:
        # ``text[: max_length - 3]`` would be a negative slice here and return
        # something longer than max_length.
        return text[: max(max_length, 0)]
    return text[: max_length - 3] + "..."
244

245

246
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Lazily yield consecutive lists of up to ``size`` items from ``iterable``.

    The last batch may be shorter. Nothing is yielded for an empty input.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    items = iter(iterable)
    # Two-argument iter(): keep calling the lambda until it returns [].
    yield from iter(lambda: list(islice(items, size)), [])
258

259

260
def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    On Windows, os.path.join(dirname, basename) inserts a backslash between
    dirname and basename unless dirname already ends with a separator, so a
    trailing "/" is added to a non-empty dirname first. An empty dirname is
    passed through unchanged, giving just basename.
    """
    if dirname and not dirname.endswith("/"):
        dirname += "/"
    return os.path.join(dirname, basename)
269

270

271
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally wrap ``items`` in a ``tqdm.auto`` progress bar labelled ``desc``.

    Returns ``items`` itself when ``show_progress`` is false, or when importing
    or constructing the tqdm bar raises ImportError.
    """
    if not show_progress:
        return items
    try:
        from tqdm.auto import tqdm

        progress = tqdm(items, desc=desc)
    except ImportError:
        return items
    return progress
284

285

286
def count_tokens(text: str) -> int:
    """Return how many tokens the global tokenizer produces for ``text``."""
    return len(get_tokenizer()(text))
290

291

292
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """Return the ``tokenize`` method of a Hugging Face ``AutoTokenizer``.

    Args:
        model_name(str): the model name of the tokenizer.
                        For instance, fxmarty/tiny-llama-fast-tokenizer.

    Returns:
        Callable[[str], List[str]]: ``tokenize`` of the tokenizer loaded via
            ``AutoTokenizer.from_pretrained(model_name)``.

    Raises:
        ValueError: if ``transformers`` is not installed. The ValueError type
            is kept for existing callers; the original ImportError is chained
            as ``__cause__``.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError as err:
        raise ValueError(
            "`transformers` package not found, please run `pip install transformers`"
        ) from err
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize
306

307

308
def get_cache_dir() -> str:
    """Return llama_index's cache directory, creating it if it does not exist.

    Resolution order:
      1. ``$LLAMA_INDEX_CACHE_DIR`` when that variable is set (user override);
      2. ``~/Library/Caches/llama_index`` on macOS;
      3. ``/tmp/llama_index`` on other POSIX systems (Linux, Unix, AIX, ...);
      4. otherwise (presumably Windows) ``llama_index`` under
         ``%LOCALAPPDATA%``, or under ``~/AppData/Local`` when that variable
         is unset or empty.

    Returns:
        str: the cache directory path.
    """
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        cache_path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])
    elif sys.platform == "darwin":
        cache_path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
    elif os.name == "posix":
        cache_path = Path("/tmp/llama_index")
    else:
        local_app_data = os.environ.get("LOCALAPPDATA") or os.path.expanduser(
            "~\\AppData\\Local"
        )
        cache_path = Path(local_app_data, "llama_index")

    if not os.path.exists(cache_path):
        # exist_ok=True prevents https://github.com/jerryjliu/llama_index/issues/7362
        os.makedirs(cache_path, exist_ok=True)
    return str(cache_path)
336

337

338
def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function.

    The sync version is added as a function attribute, ``func.sync``, and
    ``func`` itself is returned. Calling ``func.sync(*args, **kwds)`` runs the
    coroutine to completion on the calling thread's current event loop. If the
    thread has no current loop (worker threads, after ``asyncio.run()``
    cleared it, or Python versions where ``asyncio.get_event_loop`` no longer
    creates one) or that loop is closed, a new loop is created and installed.

    Args:
        func(Any): the async function for which a sync variant will be built.

    Returns:
        Any: ``func``, with the added ``sync`` attribute.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            # Install the new loop as current so later calls from this
            # thread reuse it instead of creating another one.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func
353

354

355
# Sample text from llama_index.legacy's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM  apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""
381

382
# Named 24-bit foreground colors. Each value is an ANSI SGR parameter string of
# the form "38;2;<r>;<g>;<b>", spliced into "\033[1;3;<value>m" by
# _get_colored_text. Insertion order is the order get_color_mapping cycles in.
_LLAMA_INDEX_COLORS = dict(
    llama_pink="38;2;237;90;200",
    llama_blue="38;2;90;149;237",
    llama_turquoise="38;2;11;159;203",
    llama_lavender="38;2;155;135;227",
)

# Basic ANSI foreground colors plus a 256-color-palette pink ("38;5;<index>"),
# in the same SGR-parameter format and with the same ordering role as above.
_ANSI_COLORS = dict(
    red="31",
    green="32",
    yellow="33",
    blue="34",
    magenta="35",
    cyan="36",
    pink="38;5;200",
)
398

399

400
def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Assign a color name to each item, cycling through the chosen palette.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate whether to
            use LlamaIndex colors or ANSI colors. Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to color names (palette keys, not
            escape codes). A repeated item keeps the color assigned to its
            last occurrence.
    """
    palette = _LLAMA_INDEX_COLORS if use_llama_index_colors else _ANSI_COLORS
    color_names = list(palette)
    mapping: Dict[str, str] = {}
    for position, item in enumerate(items):
        mapping[item] = color_names[position % len(color_names)]
    return mapping
422

423

424
def _get_colored_text(text: str, color: str) -> str:
    """
    Wrap ``text`` in ANSI escape codes for bold, italic and the given color.

    Args:
        text (str): Input text.
        color (str): Name of a color from the LlamaIndex or ANSI palettes.
            Unknown names fall back to bold + italic with no color.

    Returns:
        str: Colored version of the input text, followed by a reset code.
    """
    palette = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
    sgr_code = palette.get(color)
    if sgr_code is None:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized
    return f"\033[1;3;{sgr_code}m{text}\033[0m"
443

444

445
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print ``text`` to stdout, optionally colored.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
            ``None`` prints the text unstyled.
        end (str, optional): String appended after the last character of the
            text. Defaults to "" (no newline).

    Returns:
        None
    """
    if color is None:
        print(text, end=end)
    else:
        print(_get_colored_text(text, color), end=end)
461

462

463
def infer_torch_device() -> str:
    """Infer the input to torch.device.

    Returns:
        str: ``"cuda"`` if CUDA is available, else ``"mps"`` if the Apple MPS
            backend exists in this torch build and is available, else ``"cpu"``.
    """
    # Import directly. The previous version relied on catching the
    # UnboundLocalError (a NameError) raised by the function-local ``torch``.
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
476

477

478
def unit_generator(x: Any) -> Generator[Any, None, None]:
    """Build a generator that yields ``x`` exactly once and then stops.

    Args:
        x (Any): the single element to yield.

    Yields:
        Any: ``x`` itself.
    """
    for element in (x,):
        yield element
488

489

490
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """Build an async generator that yields ``x`` exactly once and then stops.

    Args:
        x (Any): the single element to yield.

    Yields:
        Any: ``x`` itself.
    """
    for element in (x,):
        yield element
500

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.