llama-index

Форк
0
62 строки · 2.1 Кб
1
from contextlib import contextmanager
2
from typing import TYPE_CHECKING, Callable, Iterator
3

4
from llama_index.legacy.llms.huggingface import HuggingFaceLLM
5
from llama_index.legacy.llms.llama_cpp import LlamaCPP
6
from llama_index.legacy.llms.llm import LLM
7

8
if TYPE_CHECKING:
9
    from lmformatenforcer import CharacterLevelParser
10

11

12
def build_lm_format_enforcer_function(
13
    llm: LLM, character_level_parser: "CharacterLevelParser"
14
) -> Callable:
15
    """Prepare for using the LM format enforcer.
16
    This builds the processing function that will be injected into the LLM to
17
    activate the LM Format Enforcer.
18
    """
19
    if isinstance(llm, HuggingFaceLLM):
20
        from lmformatenforcer.integrations.transformers import (
21
            build_transformers_prefix_allowed_tokens_fn,
22
        )
23

24
        return build_transformers_prefix_allowed_tokens_fn(
25
            llm._tokenizer, character_level_parser
26
        )
27
    if isinstance(llm, LlamaCPP):
28
        from llama_cpp import LogitsProcessorList
29
        from lmformatenforcer.integrations.llamacpp import (
30
            build_llamacpp_logits_processor,
31
        )
32

33
        return LogitsProcessorList(
34
            [build_llamacpp_logits_processor(llm._model, character_level_parser)]
35
        )
36
    raise ValueError("Unsupported LLM type")
37

38

39
@contextmanager
40
def activate_lm_format_enforcer(
41
    llm: LLM, lm_format_enforcer_fn: Callable
42
) -> Iterator[None]:
43
    """Activate the LM Format Enforcer for the given LLM.
44

45
    with activate_lm_format_enforcer(llm, lm_format_enforcer_fn):
46
        llm.complete(...)
47
    """
48
    if isinstance(llm, HuggingFaceLLM):
49
        generate_kwargs_key = "prefix_allowed_tokens_fn"
50
    elif isinstance(llm, LlamaCPP):
51
        generate_kwargs_key = "logits_processor"
52
    else:
53
        raise ValueError("Unsupported LLM type")
54
    llm.generate_kwargs[generate_kwargs_key] = lm_format_enforcer_fn
55

56
    try:
57
        # This is where the user code will run
58
        yield
59
    finally:
60
        # We remove the token enforcer function from the generate_kwargs at the end
61
        # in case other code paths use the same llm object.
62
        del llm.generate_kwargs[generate_kwargs_key]
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.