# Source page header (GitVerse scrape): llama-index fork — 112 lines, 4.1 KB.
"""LLM reranker."""

from typing import Callable, List, Optional

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.indices.utils import (
    default_format_node_batch_fn,
    default_parse_choice_select_answer_fn,
)
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
from llama_index.legacy.prompts import BasePromptTemplate
from llama_index.legacy.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.schema import NodeWithScore, QueryBundle
from llama_index.legacy.service_context import ServiceContext

class LLMRerank(BaseNodePostprocessor):
    """LLM-based reranker.

    Re-ranks retrieved nodes by asking an LLM to choose and score the most
    relevant ones for a query.  Candidate nodes are sent to the LLM in
    batches of ``choice_batch_size``; each batch response is parsed into
    chosen (1-based) node numbers plus relevance scores, and the union of
    all batch picks is sorted by score so the top ``top_n`` are returned.
    """

    # Maximum number of re-ranked nodes returned by _postprocess_nodes.
    top_n: int = Field(description="Top N nodes to return.")
    # Prompt asking the LLM to select relevant nodes from a formatted batch.
    choice_select_prompt: BasePromptTemplate = Field(
        description="Choice select prompt."
    )
    # Number of candidate nodes presented to the LLM per prediction call.
    choice_batch_size: int = Field(description="Batch size for choice select.")
    # Supplies the LLM used for prediction; excluded from serialization.
    service_context: ServiceContext = Field(
        description="Service context.", exclude=True
    )

    # Formats a batch of nodes into the prompt's context_str.
    _format_node_batch_fn: Callable = PrivateAttr()
    # Parses a raw LLM answer into (choice numbers, relevance scores).
    _parse_choice_select_answer_fn: Callable = PrivateAttr()

    def __init__(
        self,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        top_n: int = 10,
    ) -> None:
        """Initialize the reranker.

        Args:
            choice_select_prompt: Prompt for choice selection; defaults to
                ``DEFAULT_CHOICE_SELECT_PROMPT``.
            choice_batch_size: Nodes per LLM call.  Defaults to 10.
            format_node_batch_fn: Overrides the default node-batch formatter.
            parse_choice_select_answer_fn: Overrides the default answer parser.
            service_context: Provides the LLM; defaults to
                ``ServiceContext.from_defaults()``.
            top_n: Number of nodes to keep after re-ranking.  Defaults to 10.
        """
        choice_select_prompt = choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        service_context = service_context or ServiceContext.from_defaults()

        # Private attrs are set directly; only pydantic fields go through
        # super().__init__ below.
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )

        super().__init__(
            choice_select_prompt=choice_select_prompt,
            choice_batch_size=choice_batch_size,
            service_context=service_context,
            top_n=top_n,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"choice_select_prompt": self.choice_select_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "choice_select_prompt" in prompts:
            self.choice_select_prompt = prompts["choice_select_prompt"]

    @classmethod
    def class_name(cls) -> str:
        return "LLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Re-rank ``nodes`` against ``query_bundle`` via LLM choice selection.

        Args:
            nodes: Candidate nodes (with retrieval scores) to re-rank.
            query_bundle: The query; required for this postprocessor.

        Returns:
            Up to ``top_n`` nodes, re-scored by LLM relevance, best first.

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []

        initial_results: List[NodeWithScore] = []
        for idx in range(0, len(nodes), self.choice_batch_size):
            nodes_batch = [
                node.node for node in nodes[idx : idx + self.choice_batch_size]
            ]

            query_str = query_bundle.query_str
            fmt_batch_str = self._format_node_batch_fn(nodes_batch)
            # call each batch independently
            raw_response = self.service_context.llm.predict(
                self.choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(nodes_batch)
            )
            # LLM answers are 1-based choice numbers; convert to 0-based
            # indices into this batch.
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            choice_nodes = [nodes_batch[ci] for ci in choice_idxs]
            # Neutral score fallback when the parser yields no relevances.
            relevances = relevances or [1.0 for _ in choice_nodes]
            initial_results.extend(
                [
                    NodeWithScore(node=node, score=relevance)
                    for node, relevance in zip(choice_nodes, relevances)
                ]
            )

        # Merge all batch picks, sort by score (None treated as 0.0), keep top_n.
        return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
113

# --- Hosting-page boilerplate (GitVerse cookie notice), not part of this
# module: the site uses cookies per its privacy and cookie policies; pressing
# "Accept" consents to personal-data processing by JSC SberTech to improve the
# website and the GitVerse service; cookies can be disabled in browser settings.