llama-index

Форк
0
388 строк · 13.3 Кб
1
"""Node postprocessor."""
2

3
import logging
4
from typing import Dict, List, Optional, cast
5

6
from llama_index.legacy.bridge.pydantic import Field, validator
7
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
8
from llama_index.legacy.prompts.base import PromptTemplate
9
from llama_index.legacy.response_synthesizers import (
10
    ResponseMode,
11
    get_response_synthesizer,
12
)
13
from llama_index.legacy.schema import NodeRelationship, NodeWithScore, QueryBundle
14
from llama_index.legacy.service_context import ServiceContext
15
from llama_index.legacy.storage.docstore import BaseDocumentStore
16

17
logger = logging.getLogger(__name__)
18

19

20
class KeywordNodePostprocessor(BaseNodePostprocessor):
    """Filter nodes by keyword presence.

    Keeps a node only if it matches every phrase in ``required_keywords``
    and matches none of the phrases in ``exclude_keywords``. Matching is
    done with spaCy's ``PhraseMatcher`` over a blank pipeline for ``lang``.
    """

    required_keywords: List[str] = Field(default_factory=list)
    exclude_keywords: List[str] = Field(default_factory=list)
    lang: str = Field(default="en")

    @classmethod
    def class_name(cls) -> str:
        return "KeywordNodePostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Return the subset of ``nodes`` passing the keyword filters."""
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        from spacy.matcher import PhraseMatcher

        # Blank pipeline: tokenization only, no model download needed.
        nlp = spacy.blank(self.lang)
        required = PhraseMatcher(nlp.vocab)
        required.add("RequiredKeywords", list(nlp.pipe(self.required_keywords)))
        excluded = PhraseMatcher(nlp.vocab)
        excluded.add("ExcludeKeywords", list(nlp.pipe(self.exclude_keywords)))

        kept: List[NodeWithScore] = []
        for scored in nodes:
            doc = nlp(scored.node.get_content())
            # Each filter only applies when its keyword list is non-empty.
            fails_required = self.required_keywords and not required(doc)
            hits_excluded = self.exclude_keywords and excluded(doc)
            if not fails_required and not hits_excluded:
                kept.append(scored)

        return kept
62

63

64
class SimilarityPostprocessor(BaseNodePostprocessor):
    """Similarity-based Node processor.

    Drops nodes whose similarity score is missing or falls below
    ``similarity_cutoff``. When no cutoff is set, all nodes pass through.
    """

    # FIX: annotated as Optional[float] — the default is None, which the
    # bare `float` annotation did not reflect (pydantic v1 inferred
    # Optional implicitly; this makes the contract explicit).
    similarity_cutoff: Optional[float] = Field(default=None)

    @classmethod
    def class_name(cls) -> str:
        return "SimilarityPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Postprocess nodes.

        Args:
            nodes: scored nodes to filter.
            query_bundle: unused; accepted for interface compatibility.

        Returns:
            A new list with only nodes whose score passes the cutoff,
            preserving the input order.
        """
        if self.similarity_cutoff is None:
            # No cutoff configured: every node passes.
            return list(nodes)

        cutoff = cast(float, self.similarity_cutoff)
        new_nodes = []
        for node in nodes:
            # A node without a score cannot be compared, so it is dropped
            # whenever a cutoff is active (matches prior behavior).
            if node.score is not None and cast(float, node.score) >= cutoff:
                new_nodes.append(node)

        return new_nodes
95

96

97
def get_forward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore
) -> Dict[str, NodeWithScore]:
    """Get forward nodes.

    Starting from ``node_with_score`` (always included), follow NEXT
    relationships through ``docstore``, collecting up to ``num_nodes``
    additional nodes.

    Args:
        node_with_score: the starting node.
        num_nodes: maximum number of forward nodes to fetch.
        docstore: document store used to resolve node ids.

    Returns:
        Mapping of node_id -> NodeWithScore in traversal order; fetched
        nodes carry no score.
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    cur_count = 0
    # get forward nodes in an iterative manner
    while cur_count < num_nodes:
        if NodeRelationship.NEXT not in node.relationships:
            break

        next_node_info = node.next_node
        if next_node_info is None:
            break

        next_node_id = next_node_info.node_id
        next_node = docstore.get_node(next_node_id)
        if next_node is None:
            # FIX: guard against a dangling NEXT relationship, mirroring
            # the None-check already present in get_backward_nodes.
            break
        nodes[next_node.node_id] = NodeWithScore(node=next_node)
        node = next_node
        cur_count += 1
    return nodes
119

120

121
def get_backward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore
) -> Dict[str, NodeWithScore]:
    """Get backward nodes.

    Walks PREV relationships from ``node_with_score`` (always included),
    fetching up to ``num_nodes`` predecessors from ``docstore``. Stops early
    on a missing relationship or an unresolved node id.
    """
    current = node_with_score.node
    collected: Dict[str, NodeWithScore] = {current.node_id: node_with_score}
    # Iteratively follow the PREV chain, at most num_nodes hops.
    for _ in range(num_nodes):
        info = current.prev_node
        if info is None:
            break
        fetched = docstore.get_node(info.node_id)
        if fetched is None:
            break
        # Fetched predecessors have no retrieval score of their own.
        collected[fetched.node_id] = NodeWithScore(node=fetched)
        current = fetched
    return collected
141

142

143
class PrevNextNodePostprocessor(BaseNodePostprocessor):
    """Previous/Next Node post-processor.

    Allows users to fetch additional nodes from the document store,
    based on the relationships of the nodes.

    NOTE: this is a beta feature.

    Args:
        docstore (BaseDocumentStore): The document store.
        num_nodes (int): The number of nodes to return (default: 1)
        mode (str): The mode of the post-processor.
            Can be "previous", "next", or "both".

    """

    docstore: BaseDocumentStore
    num_nodes: int = Field(default=1)
    mode: str = Field(default="next")

    @validator("mode")
    def _validate_mode(cls, v: str) -> str:
        """Validate mode."""
        if v not in ["next", "previous", "both"]:
            raise ValueError(f"Invalid mode: {v}")
        return v

    @classmethod
    def class_name(cls) -> str:
        return "PrevNextNodePostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Postprocess nodes.

        Expands each input node with neighbors fetched per ``mode``, then
        orders the merged set so that related nodes sit next to each other.

        Args:
            nodes: scored nodes to expand.
            query_bundle: unused; accepted for interface compatibility.

        Returns:
            The expanded node list, partially ordered by prev/next links.
        """
        # Keyed by node_id so the same node fetched twice is kept once.
        all_nodes: Dict[str, NodeWithScore] = {}
        for node in nodes:
            all_nodes[node.node.node_id] = node
            if self.mode == "next":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
            elif self.mode == "previous":
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            elif self.mode == "both":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            else:
                # Unreachable when mode passed pydantic validation; kept as a
                # defensive guard.
                raise ValueError(f"Invalid mode: {self.mode}")

        # Insertion pass: place each node adjacent to a node it is linked to
        # (before a node whose PREV points at it, after a node whose NEXT
        # points at it); otherwise append at the end. The result is a
        # best-effort ordering, not a full topological sort.
        all_nodes_values: List[NodeWithScore] = list(all_nodes.values())
        sorted_nodes: List[NodeWithScore] = []
        for node in all_nodes_values:
            # variable to check if cand node is inserted
            node_inserted = False
            for i, cand in enumerate(sorted_nodes):
                node_id = node.node.node_id
                # prepend to current candidate
                prev_node_info = cand.node.prev_node
                next_node_info = cand.node.next_node
                if prev_node_info is not None and node_id == prev_node_info.node_id:
                    node_inserted = True
                    sorted_nodes.insert(i, node)
                    break
                # append to current candidate
                elif next_node_info is not None and node_id == next_node_info.node_id:
                    node_inserted = True
                    sorted_nodes.insert(i + 1, node)
                    break

            if not node_inserted:
                sorted_nodes.append(node)

        return sorted_nodes
221

222

223
# Few-shot prompt used by AutoPrevNextNodePostprocessor to decide, from one
# context chunk and the query, whether to fetch PREVIOUS, NEXT, or no extra
# nodes. Template variables: {context_str}, {query_str}.
DEFAULT_INFER_PREV_NEXT_TMPL = (
    "The current context information is provided. \n"
    "A question is also provided. \n"
    "You are a retrieval agent deciding whether to search the "
    "document store for additional prior context or future context. \n"
    "Given the context and question, return PREVIOUS or NEXT or NONE. \n"
    "Examples: \n\n"
    "Context: Describes the author's experience at Y Combinator."
    "Question: What did the author do after his time at Y Combinator? \n"
    "Answer: NEXT \n\n"
    "Context: Describes the author's experience at Y Combinator."
    "Question: What did the author do before his time at Y Combinator? \n"
    "Answer: PREVIOUS \n\n"
    "Context: Describe the author's experience at Y Combinator."
    "Question: What did the author do at Y Combinator? \n"
    "Answer: NONE \n\n"
    "Context: {context_str}\n"
    "Question: {query_str}\n"
    "Answer: "
)
243

244

245
# Refine-step variant of the prev/next inference prompt, used by the
# TREE_SUMMARIZE response synthesizer when the context spans multiple LLM
# calls. Template variables: {context_msg}, {query_str}, {existing_answer}.
DEFAULT_REFINE_INFER_PREV_NEXT_TMPL = (
    "The current context information is provided. \n"
    "A question is also provided. \n"
    "An existing answer is also provided.\n"
    "You are a retrieval agent deciding whether to search the "
    "document store for additional prior context or future context. \n"
    "Given the context, question, and previous answer, "
    "return PREVIOUS or NEXT or NONE.\n"
    "Examples: \n\n"
    "Context: {context_msg}\n"
    "Question: {query_str}\n"
    "Existing Answer: {existing_answer}\n"
    "Answer: "
)
259

260

261
class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
    """Previous/Next Node post-processor.

    Allows users to fetch additional nodes from the document store,
    based on the prev/next relationships of the nodes.

    NOTE: difference with PrevNextPostprocessor is that
    this infers forward/backwards direction via an LLM call per node.

    NOTE: this is a beta feature.

    Args:
        docstore (BaseDocumentStore): The document store.
        service_context (ServiceContext): Provides the LLM used for inference.
        num_nodes (int): The number of nodes to return (default: 1)
        infer_prev_next_tmpl (str): The template to use for inference.
            Required fields are {context_str} and {query_str}.
        refine_prev_next_tmpl (str): Refine-step template; required fields
            are {context_msg}, {query_str}, {existing_answer}.
        verbose (bool): If True, print the predicted mode per node.

    """

    docstore: BaseDocumentStore
    service_context: ServiceContext
    num_nodes: int = Field(default=1)
    infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
    refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
    verbose: bool = Field(default=False)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def class_name(cls) -> str:
        return "AutoPrevNextNodePostprocessor"

    def _parse_prediction(self, raw_pred: str) -> str:
        """Parse the raw LLM output into "previous", "next", or "none".

        Matching is substring-based and case-insensitive; "previous" is
        checked first so it cannot be shadowed by the other keywords.

        Raises:
            ValueError: if none of the three keywords is present.
        """
        pred = raw_pred.strip().lower()
        if "previous" in pred:
            return "previous"
        elif "next" in pred:
            return "next"
        elif "none" in pred:
            return "none"
        raise ValueError(f"Invalid prediction: {raw_pred}")

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Postprocess nodes.

        For each input node, asks the LLM whether PREVIOUS or NEXT context
        should be fetched from the docstore, merges the fetched neighbors,
        and returns the union sorted by node id.

        Raises:
            ValueError: if ``query_bundle`` is None, or a prediction/mode
                is invalid.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle.")

        infer_prev_next_prompt = PromptTemplate(
            self.infer_prev_next_tmpl,
        )
        refine_infer_prev_next_prompt = PromptTemplate(self.refine_prev_next_tmpl)

        # use response builder instead of llm directly
        # to be more robust to handling long context.
        # FIX: construct once — it was rebuilt identically on every loop
        # iteration, which is loop-invariant work.
        response_builder = get_response_synthesizer(
            service_context=self.service_context,
            text_qa_template=infer_prev_next_prompt,
            refine_template=refine_infer_prev_next_prompt,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

        all_nodes: Dict[str, NodeWithScore] = {}
        for node in nodes:
            all_nodes[node.node.node_id] = node
            raw_pred = response_builder.get_response(
                text_chunks=[node.node.get_content()],
                query_str=query_bundle.query_str,
            )
            raw_pred = cast(str, raw_pred)
            mode = self._parse_prediction(raw_pred)

            logger.debug(f"> Postprocessor Predicted mode: {mode}")
            if self.verbose:
                print(f"> Postprocessor Predicted mode: {mode}")

            if mode == "next":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
            elif mode == "previous":
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            elif mode == "none":
                pass
            else:
                # Unreachable: _parse_prediction only returns the three
                # handled values; kept as a defensive guard.
                raise ValueError(f"Invalid mode: {mode}")

        # NOTE: ordering by node_id is arbitrary but deterministic.
        sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.node.node_id)
        return list(sorted_nodes)
356

357

358
class LongContextReorder(BaseNodePostprocessor):
    """Reorder nodes so the highest-scoring ones sit at both ends.

    Models struggle to access significant details found in the center of
    extended contexts. A study (https://arxiv.org/abs/2307.03172) observed
    that the best performance typically arises when crucial data is
    positioned at the start or conclusion of the input context.
    Additionally, as the input context lengthens, performance drops
    notably, even in models designed for long contexts.
    """

    @classmethod
    def class_name(cls) -> str:
        return "LongContextReorder"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Interleave nodes (sorted ascending by score) toward the two ends."""
        by_score: List[NodeWithScore] = sorted(
            nodes, key=lambda n: 0 if n.score is None else n.score
        )
        # Even positions of the ascending order go to the front in reverse
        # (equivalent to repeated insert(0, ...)); odd positions keep their
        # order at the back. Net effect: top scores at both extremes.
        front = by_score[::2]
        front.reverse()
        back = by_score[1::2]
        return front + back
389

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.