# llama-index
#
# Форк (fork)
# 0
# 228 строк · 7.6 Кб (228 lines · 7.6 KB)
1
"""Node recency post-processor."""
2

3
from datetime import datetime
4
from typing import List, Optional, Set
5

6
import numpy as np
7
import pandas as pd
8

9
from llama_index.legacy.bridge.pydantic import Field
10
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
11
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle
12
from llama_index.legacy.service_context import ServiceContext
13

14
# NOTE: currently not being used
15
# DEFAULT_INFER_RECENCY_TMPL = (
16
#     "A question is provided.\n"
17
#     "The goal is to determine whether the question requires finding the most recent "
18
#     "context.\n"
19
#     "Please respond with YES or NO.\n"
20
#     "Question: What is the current status of the patient?\n"
21
#     "Answer: YES\n"
22
#     "Question: What happened in the Battle of Yorktown?\n"
23
#     "Answer: NO\n"
24
#     "Question: What are the most recent changes to the project?\n"
25
#     "Answer: YES\n"
26
#     "Question: How did Harry defeat Voldemort in the Battle of Hogwarts?\n"
27
#     "Answer: NO\n"
28
#     "Question: {query_str}\n"
29
#     "Answer: "
30
# )
31

32

33
# def parse_recency_pred(pred: str) -> bool:
34
#     """Parse recency prediction."""
35
#     if "YES" in pred:
36
#         return True
37
#     elif "NO" in pred:
38
#         return False
39
#     else:
40
#         raise ValueError(f"Invalid recency prediction: {pred}.")
41

42

43
class FixedRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor that keeps only the most recent nodes.

    This post-processor does the following steps:

    - Sorts nodes by the date stored under ``date_key`` in each node's
      metadata, newest first.
    - Takes the first ``top_k`` nodes (by default 1), and uses those to
      synthesize an answer.

    NOTE: no query-based check for whether the question is temporal-related
    is performed; the sort is always applied.
    """

    # Not read by ``_postprocess_nodes``.
    service_context: ServiceContext
    # Number of newest nodes to return.
    top_k: int = 1
    # Metadata key holding each node's date; the values are passed to
    # ``pd.to_datetime``, so they must be parseable by it.
    date_key: str = "date"

    @classmethod
    def class_name(cls) -> str:
        """Return this post-processor's class name string."""
        return "FixedRecencyPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Return the ``top_k`` most recent nodes, newest first.

        Args:
            nodes: Candidate nodes; each node's metadata must contain
                ``date_key``.
            query_bundle: Must not be None; its contents are not otherwise
                inspected.

        Returns:
            At most ``top_k`` of the input nodes, ordered newest first.

        Raises:
            ValueError: If ``query_bundle`` is None.
            KeyError: If a node's metadata lacks ``date_key``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        # sort nodes by date: argsort is ascending, so flip for newest first
        node_dates = pd.to_datetime(
            [node.node.metadata[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes = [nodes[idx] for idx in sorted_node_idxs]

        return sorted_nodes[: self.top_k]
81

82

83
# Template that wraps a node's content into query text for embedding;
# the ``{context_str}`` slot receives the node's content.
DEFAULT_QUERY_EMBEDDING_TMPL = "\n".join(
    [
        "The current document is provided.",
        "----------------",
        "{context_str}",
        "----------------",
        "Given the document, we wish to find documents that contain ",
        "similar context. Note that these documents are older than the current "
        "document, meaning that certain details may be changed. ",
        "However, the high-level context should be similar.",
        "",
    ]
)
93

94

95
class EmbeddingRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor that drops older near-duplicate nodes.

    This post-processor does the following steps:

    - Sorts nodes by the date stored under ``date_key`` in their metadata,
      newest first.
    - For each node that is kept, looks at the subsequent (older) nodes and
      filters out those whose embedding similarity with the current node
      exceeds ``similarity_cutoff``, because such a node may have overlapping
      content with the current node but is also out of date.

    NOTE: no query-based check for whether the question is temporal-related
    is performed; sorting and filtering are always applied.
    """

    # Supplies ``embed_model``, used for both text and query embeddings.
    service_context: ServiceContext
    # Metadata key holding each node's date; the values are passed to
    # ``pd.to_datetime``, so they must be parseable by it.
    date_key: str = "date"
    # An older node whose dot-product score against a newer node's query
    # embedding is above this value is dropped.
    similarity_cutoff: float = Field(default=0.7)
    # Template with a ``{context_str}`` slot, used to build the query text
    # from a newer node's content.
    query_embedding_tmpl: str = Field(default=DEFAULT_QUERY_EMBEDDING_TMPL)

    @classmethod
    def class_name(cls) -> str:
        """Return this post-processor's class name string."""
        return "EmbeddingRecencyPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Sort nodes newest first and drop older near-duplicates.

        Args:
            nodes: Candidate nodes; each node's metadata must contain
                ``date_key``.
            query_bundle: Must not be None; its contents are not otherwise
                inspected.

        Returns:
            The surviving nodes, ordered newest first.

        Raises:
            ValueError: If ``query_bundle`` is None.
            KeyError: If a node's metadata lacks ``date_key``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        # sort nodes by date: argsort is ascending, so flip for newest first
        node_dates = pd.to_datetime(
            [node.node.metadata[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes: List[NodeWithScore] = [nodes[idx] for idx in sorted_node_idxs]

        # Embed in *sorted* order: the loop below indexes ``text_embeddings``
        # with positions in ``sorted_nodes``, so the two lists must line up.
        embed_model = self.service_context.embed_model
        texts = [
            node.get_content(metadata_mode=MetadataMode.EMBED)
            for node in sorted_nodes
        ]
        text_embeddings = embed_model.get_text_embedding_batch(texts=texts)

        node_ids_to_skip: Set[str] = set()
        for idx, node in enumerate(sorted_nodes):
            if node.node.node_id in node_ids_to_skip:
                continue

            # Older nodes that have not already been filtered out.
            candidate_idxs = [
                idx2
                for idx2 in range(idx + 1, len(sorted_nodes))
                if sorted_nodes[idx2].node.node_id not in node_ids_to_skip
            ]
            if not candidate_idxs:
                # Nothing left to compare against; avoid an embedding call.
                continue

            # get query embedding for the "query" node
            # NOTE: not the same as the text embedding because
            # we want to optimize for retrieval results
            query_text = self.query_embedding_tmpl.format(
                context_str=node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            query_embedding = embed_model.get_query_embedding(query_text)

            for idx2 in candidate_idxs:
                # Raw dot product; this equals cosine similarity only if the
                # embed model returns unit-normalized vectors (not checked).
                score = np.dot(query_embedding, text_embeddings[idx2])
                if score > self.similarity_cutoff:
                    node_ids_to_skip.add(sorted_nodes[idx2].node.node_id)

        return [
            node for node in sorted_nodes if node.node.node_id not in node_ids_to_skip
        ]
166

167

168
class TimeWeightedPostprocessor(BaseNodePostprocessor):
    """Time-weighted post-processor.

    Reranks a set of nodes based on their recency: each node's retrieval
    score is combined with a bonus derived from the time since the node was
    last accessed, the ``top_k`` best nodes are kept, and (optionally) those
    nodes are stamped as accessed now.
    """

    # Decay parameter; the recency bonus is ``(1 - time_decay) ** hours_passed``.
    time_decay: float = Field(default=0.99)
    # Metadata key holding a node's last-access time. It is subtracted from
    # ``now`` and divided by 3600 to get hours, so it is expected to be a
    # timestamp in seconds on the same scale as ``now``.
    last_accessed_key: str = "__last_accessed__"
    # If True, write ``now`` into the returned nodes' metadata.
    time_access_refresh: bool = True
    # optionally set now (makes it easier to test)
    now: Optional[float] = None
    # Maximum number of nodes to return.
    top_k: int = 1

    @classmethod
    def class_name(cls) -> str:
        """Return this post-processor's class name string."""
        return "TimeWeightedPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` by retrieval score plus a recency bonus.

        Each node's combined score is ``score + (1 - time_decay) ** hours``.
        Here ``score`` defaults to 1.0 only when the node has no score, and
        ``hours`` is the time since the node's last access (0 when the node
        has never been stamped).

        Args:
            nodes: Candidate nodes.
            query_bundle: Unused.

        Returns:
            Up to ``top_k`` new ``NodeWithScore`` objects, highest combined
            score first, each carrying its combined score. They wrap the same
            underlying node objects as the input.

        Raises:
            ValueError: If a node's metadata is None.
        """
        # Compare against None so an explicitly configured ``now=0.0`` is used.
        now = self.now if self.now is not None else datetime.now().timestamp()
        # TODO: refactor with get_top_k_embeddings

        similarities = []
        for node_with_score in nodes:
            # embedding similarity score; only a *missing* score defaults to
            # 1.0 -- a genuine 0.0 must not be promoted.
            raw_score = node_with_score.score
            score = 1.0 if raw_score is None else raw_score
            node = node_with_score.node
            if node.metadata is None:
                raise ValueError("metadata is None")

            # time score: a node never accessed counts as accessed just now
            last_accessed = node.metadata.get(self.last_accessed_key, None)
            if last_accessed is None:
                last_accessed = now
            hours_passed = (now - last_accessed) / 3600
            time_similarity = (1 - self.time_decay) ** hours_passed

            similarities.append(score + time_similarity)

        # sorted() is stable even with reverse=True, so ties keep input order;
        # the key avoids ever comparing NodeWithScore objects.
        sorted_tups = sorted(zip(similarities, nodes), key=lambda x: x[0], reverse=True)
        result_nodes = [
            NodeWithScore(node=n.node, score=score)
            for score, n in sorted_tups[: self.top_k]
        ]

        # set __last_accessed__ to now (mutates the shared node objects)
        if self.time_access_refresh:
            for node_with_score in result_nodes:
                node_with_score.node.metadata[self.last_accessed_key] = now

        return result_nodes
229

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.