llama-index

Форк
0
109 строк · 3.9 Кб
1
"""Optimization related classes and functions."""
2

3
import logging
4
from typing import Any, Dict, List, Optional
5

6
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
7
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
8
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
9

10
logger = logging.getLogger(__name__)
11

12

13
DEFAULT_INSTRUCTION_STR = "Given the context, please answer the final question"
14

15

16
class LongLLMLinguaPostprocessor(BaseNodePostprocessor):
17
    """Optimization of nodes.
18

19
    Compress using LongLLMLingua paper.
20

21
    """
22

23
    metadata_mode: MetadataMode = Field(
24
        default=MetadataMode.ALL, description="Metadata mode."
25
    )
26
    instruction_str: str = Field(
27
        default=DEFAULT_INSTRUCTION_STR, description="Instruction string."
28
    )
29
    target_token: int = Field(
30
        default=300, description="Target number of compressed tokens."
31
    )
32
    rank_method: str = Field(default="longllmlingua", description="Ranking method.")
33
    additional_compress_kwargs: Dict[str, Any] = Field(
34
        default_factory=dict, description="Additional compress kwargs."
35
    )
36

37
    _llm_lingua: Any = PrivateAttr()
38

39
    def __init__(
40
        self,
41
        model_name: str = "NousResearch/Llama-2-7b-hf",
42
        device_map: str = "cuda",
43
        model_config: Optional[dict] = {},
44
        open_api_config: Optional[dict] = {},
45
        metadata_mode: MetadataMode = MetadataMode.ALL,
46
        instruction_str: str = DEFAULT_INSTRUCTION_STR,
47
        target_token: int = 300,
48
        rank_method: str = "longllmlingua",
49
        additional_compress_kwargs: Optional[Dict[str, Any]] = None,
50
    ):
51
        """LongLLMLingua Compressor for Node Context."""
52
        from llmlingua import PromptCompressor
53

54
        open_api_config = open_api_config or {}
55
        additional_compress_kwargs = additional_compress_kwargs or {}
56

57
        self._llm_lingua = PromptCompressor(
58
            model_name=model_name,
59
            device_map=device_map,
60
            model_config=model_config,
61
            open_api_config=open_api_config,
62
        )
63
        super().__init__(
64
            metadata_mode=metadata_mode,
65
            instruction_str=instruction_str,
66
            target_token=target_token,
67
            rank_method=rank_method,
68
            additional_compress_kwargs=additional_compress_kwargs,
69
        )
70

71
    @classmethod
72
    def class_name(cls) -> str:
73
        return "LongLLMLinguaPostprocessor"
74

75
    def _postprocess_nodes(
76
        self,
77
        nodes: List[NodeWithScore],
78
        query_bundle: Optional[QueryBundle] = None,
79
    ) -> List[NodeWithScore]:
80
        """Optimize a node text given the query by shortening the node text."""
81
        if query_bundle is None:
82
            raise ValueError("Query bundle is required.")
83
        context_texts = [n.get_content(metadata_mode=self.metadata_mode) for n in nodes]
84
        # split by "\n\n" (recommended by LongLLMLingua authors)
85
        new_context_texts = [
86
            c for context in context_texts for c in context.split("\n\n")
87
        ]
88

89
        # You can use it this way, although the question-aware fine-grained compression hasn't been enabled.
90
        compressed_prompt = self._llm_lingua.compress_prompt(
91
            new_context_texts,  # ! Replace the previous context_list
92
            instruction=self.instruction_str,
93
            question=query_bundle.query_str,
94
            # target_token=2000,
95
            target_token=self.target_token,
96
            rank_method=self.rank_method,
97
            **self.additional_compress_kwargs,
98
        )
99

100
        compressed_prompt_txt = compressed_prompt["compressed_prompt"]
101

102
        # separate out the question and instruction (appended to top and bottom)
103
        compressed_prompt_txt_list = compressed_prompt_txt.split("\n\n")
104
        compressed_prompt_txt_list = compressed_prompt_txt_list[1:-1]
105

106
        # return nodes for each list
107
        return [
108
            NodeWithScore(node=TextNode(text=t)) for t in compressed_prompt_txt_list
109
        ]
110

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.