llama-index

Форк
0
156 строк · 5.7 Кб
1
"""Optimization related classes and functions."""
2

3
import logging
4
from typing import Callable, List, Optional
5

6
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
7
from llama_index.legacy.embeddings.base import BaseEmbedding
8
from llama_index.legacy.embeddings.openai import OpenAIEmbedding
9
from llama_index.legacy.indices.query.embedding_utils import get_top_k_embeddings
10
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
11
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle
12

13
# Module-level logger; the optimizer in this module uses it to emit DEBUG
# records listing the sentences it selected and their similarity scores.
logger = logging.getLogger(__name__)
14

15

16
class SentenceEmbeddingOptimizer(BaseNodePostprocessor):
    """Optimization of a text chunk given the query by shortening the input text.

    Each node's text is split into sentences, every sentence is embedded and
    compared with the query embedding, and the node content is replaced by
    the best-matching sentences (each padded with a window of neighbouring
    sentences for context).
    """

    percentile_cutoff: Optional[float] = Field(
        description="Percentile cutoff for the top k sentences to use."
    )
    threshold_cutoff: Optional[float] = Field(
        description="Threshold cutoff for similarity for each sentence to use."
    )

    _embed_model: BaseEmbedding = PrivateAttr()
    _tokenizer_fn: Callable[[str], List[str]] = PrivateAttr()

    context_before: Optional[int] = Field(
        description="Number of sentences before retrieved sentence for further context"
    )

    context_after: Optional[int] = Field(
        description="Number of sentences after retrieved sentence for further context"
    )

    def __init__(
        self,
        embed_model: Optional[BaseEmbedding] = None,
        percentile_cutoff: Optional[float] = None,
        threshold_cutoff: Optional[float] = None,
        tokenizer_fn: Optional[Callable[[str], List[str]]] = None,
        context_before: Optional[int] = None,
        context_after: Optional[int] = None,
    ):
        """Create a sentence-level optimizer.

        Args:
            embed_model: Embedding model used for the query and for every
                sentence. Defaults to ``OpenAIEmbedding()`` when omitted.
            percentile_cutoff: Fraction of a node's sentences to keep, e.g.
                ``0.5`` means that the top 50% of sentences will be used.
            threshold_cutoff: Similarity cutoff; only sentences with a
                similarity score higher than the threshold will be used.
                Both cutoffs can be used together.
            tokenizer_fn: Callable splitting a text into sentences. Defaults
                to the NLTK English punkt tokenizer (loaded lazily here, so
                ``nltk`` is only required when no tokenizer is supplied).
            context_before: Number of sentences to include before each
                selected sentence. Treated as 1 when ``None``.
            context_after: Number of sentences to include after each
                selected sentence. Treated as 1 when ``None``.

        Example:
            .. code-block:: python

                optimizer = SentenceEmbeddingOptimizer(
                    percentile_cutoff=0.5,
                    threshold_cutoff=0.7,
                )
                # The class is a node postprocessor, so it is registered
                # with the query engine as one.
                query_engine = index.as_query_engine(
                    node_postprocessors=[optimizer]
                )
                response = query_engine.query("<query_str>")
        """
        # Private attributes are assigned before the base initializer runs,
        # exactly as in the original implementation.
        self._embed_model = embed_model or OpenAIEmbedding()

        if tokenizer_fn is None:
            import nltk.data

            tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
            tokenizer_fn = tokenizer.tokenize
        self._tokenizer_fn = tokenizer_fn

        super().__init__(
            percentile_cutoff=percentile_cutoff,
            threshold_cutoff=threshold_cutoff,
            context_after=context_after,
            context_before=context_before,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "SentenceEmbeddingOptimizer"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Optimize a node text given the query by shortening the node text.

        Args:
            nodes: Scored nodes whose content is rewritten in place.
            query_bundle: The query. When ``None`` the nodes are returned
                untouched. If it carries no embedding yet, one is computed
                from ``embedding_strs`` and stored back on the bundle.

        Returns:
            The same ``nodes`` list, with each node's content replaced by the
            selected sentences joined with single spaces.

        Raises:
            ValueError: If no sentence of some node survives the cutoffs.
        """
        if query_bundle is None or not nodes:
            return nodes

        # The query embedding is the same for every node: compute it once
        # up front instead of re-checking inside the per-node loop.
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                self._embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        # Resolve the context-window defaults into locals. Writing them back
        # onto ``self`` would silently change the configured (and serialized)
        # state of the postprocessor as a side effect of running it.
        context_before = 1 if self.context_before is None else self.context_before
        context_after = 1 if self.context_after is None else self.context_after

        for node_with_score in nodes:
            node = node_with_score.node
            text = node.get_content(metadata_mode=MetadataMode.LLM)
            split_text = self._tokenizer_fn(text)
            num_sentences = len(split_text)

            # NOTE(review): this calls the embedding model's private batch
            # method directly; confirm whether a public batch API should be
            # used here instead.
            text_embeddings = self._embed_model._get_text_embeddings(split_text)

            num_top_k = None
            if self.percentile_cutoff is not None:
                num_top_k = int(num_sentences * self.percentile_cutoff)

            top_similarities, top_idxs = get_top_k_embeddings(
                query_embedding=query_bundle.embedding,
                embeddings=text_embeddings,
                similarity_fn=self._embed_model.similarity,
                similarity_top_k=num_top_k,
                embedding_ids=list(range(len(text_embeddings))),
                similarity_cutoff=self.threshold_cutoff,
            )

            if len(top_idxs) == 0:
                raise ValueError("Optimizer returned zero sentences.")

            # For every selected sentence take a window of neighbours,
            # clamped to the bounds of the sentence list.
            top_sentences = [
                " ".join(
                    split_text[
                        max(idx - context_before, 0) : min(
                            idx + context_after + 1, num_sentences
                        )
                    ]
                )
                for idx in top_idxs
            ]

            # Only build the debug messages when they will be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("> Top %s sentences with scores:\n", len(top_idxs))
                for rank, (sentence, similarity) in enumerate(
                    zip(top_sentences, top_similarities)
                ):
                    logger.debug("%s. %s (%s)", rank, sentence, similarity)

            node.set_content(" ".join(top_sentences))

        return nodes
157

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.