llama-index
62 lines · 2.1 KB
1from contextlib import contextmanager
2from typing import TYPE_CHECKING, Callable, Iterator
3
4from llama_index.legacy.llms.huggingface import HuggingFaceLLM
5from llama_index.legacy.llms.llama_cpp import LlamaCPP
6from llama_index.legacy.llms.llm import LLM
7
8if TYPE_CHECKING:
9from lmformatenforcer import CharacterLevelParser
10
11
def build_lm_format_enforcer_function(
    llm: LLM, character_level_parser: "CharacterLevelParser"
) -> Callable:
    """Build the token-filtering callable used by the LM Format Enforcer.

    Dispatches on the concrete LLM backend and returns the object that must
    later be injected into that backend's ``generate_kwargs`` (see
    ``activate_lm_format_enforcer``).

    Args:
        llm: The wrapped LLM; must be a ``HuggingFaceLLM`` or a ``LlamaCPP``.
        character_level_parser: The lm-format-enforcer parser describing the
            allowed output format.

    Returns:
        A backend-specific processing function/object.

    Raises:
        ValueError: If ``llm`` is neither a ``HuggingFaceLLM`` nor a
            ``LlamaCPP`` instance.
    """
    if isinstance(llm, HuggingFaceLLM):
        # Imported lazily so transformers is only required for this backend.
        from lmformatenforcer.integrations.transformers import (
            build_transformers_prefix_allowed_tokens_fn,
        )

        return build_transformers_prefix_allowed_tokens_fn(
            llm._tokenizer, character_level_parser
        )
    elif isinstance(llm, LlamaCPP):
        # Imported lazily so llama_cpp is only required for this backend.
        from llama_cpp import LogitsProcessorList
        from lmformatenforcer.integrations.llamacpp import (
            build_llamacpp_logits_processor,
        )

        processor = build_llamacpp_logits_processor(
            llm._model, character_level_parser
        )
        return LogitsProcessorList([processor])
    else:
        raise ValueError("Unsupported LLM type")
37
38
@contextmanager
def activate_lm_format_enforcer(
    llm: LLM, lm_format_enforcer_fn: Callable
) -> Iterator[None]:
    """Temporarily activate the LM Format Enforcer for the given LLM.

    Example:
        with activate_lm_format_enforcer(llm, lm_format_enforcer_fn):
            llm.complete(...)

    Args:
        llm: The wrapped LLM; must be a ``HuggingFaceLLM`` or a ``LlamaCPP``.
        lm_format_enforcer_fn: The processing function built by
            ``build_lm_format_enforcer_function``.

    Raises:
        ValueError: If ``llm`` is neither a ``HuggingFaceLLM`` nor a
            ``LlamaCPP`` instance.
    """
    if isinstance(llm, HuggingFaceLLM):
        generate_kwargs_key = "prefix_allowed_tokens_fn"
    elif isinstance(llm, LlamaCPP):
        generate_kwargs_key = "logits_processor"
    else:
        raise ValueError("Unsupported LLM type")

    # Remember any pre-existing value so we can restore it on exit instead of
    # unconditionally deleting the key. Sentinel distinguishes "absent" from
    # a legitimate None value.
    _missing = object()
    previous = llm.generate_kwargs.get(generate_kwargs_key, _missing)
    llm.generate_kwargs[generate_kwargs_key] = lm_format_enforcer_fn

    try:
        # This is where the user code will run.
        yield
    finally:
        # Undo our change so other code paths using the same llm object are
        # unaffected. pop(..., None) avoids a KeyError (which would mask the
        # user's exception) if user code already removed the key.
        if previous is _missing:
            llm.generate_kwargs.pop(generate_kwargs_key, None)
        else:
            llm.generate_kwargs[generate_kwargs_key] = previous
63