llama-index

Форк
0
158 строк · 5.9 Кб
1
import logging
2
from typing import Any, Dict, List, Optional, Sequence
3

4
from llama_index.legacy.bridge.pydantic import Field
5
from llama_index.legacy.llms import LLM, ChatMessage, ChatResponse, OpenAI
6
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
7
from llama_index.legacy.prompts import BasePromptTemplate
8
from llama_index.legacy.prompts.default_prompts import RANKGPT_RERANK_PROMPT
9
from llama_index.legacy.prompts.mixin import PromptDictType
10
from llama_index.legacy.schema import NodeWithScore, QueryBundle
11
from llama_index.legacy.utils import print_text
12

13
# Module-scoped logger named after this module. Its threshold is pinned to
# WARNING here, so DEBUG/INFO records emitted through it are dropped unless
# this logger's level is changed elsewhere.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
15

16

17
class RankGPTRerank(BaseNodePostprocessor):
    """RankGPT-based reranker.

    Builds a chat conversation in which every candidate passage is sent as its
    own numbered user turn (acknowledged by an assistant turn), asks the LLM to
    rank them with ``rankgpt_rerank_prompt``, and parses the numbers found in
    the reply as a 1-based permutation of the passages. At most ``top_n``
    nodes are returned; each keeps its original node and score.
    """

    top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
    llm: LLM = Field(
        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
        description="LLM to use for rankGPT",
    )
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    rankgpt_rerank_prompt: BasePromptTemplate = Field(
        description="rankGPT rerank prompt."
    )

    def __init__(
        self,
        top_n: int = 5,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        rankgpt_rerank_prompt: Optional[BasePromptTemplate] = None,
    ):
        """Create the reranker.

        Args:
            top_n: Maximum number of nodes returned after reranking.
            llm: LLM used for ranking. When omitted (``None``), the field's
                ``default_factory`` supplies ``OpenAI(model="gpt-3.5-turbo-16k")``.
            verbose: If True, print the computed rank list.
            rankgpt_rerank_prompt: Final-turn prompt template; falls back to
                ``RANKGPT_RERANK_PROMPT`` when not given (or falsy).
        """
        rankgpt_rerank_prompt = rankgpt_rerank_prompt or RANKGPT_RERANK_PROMPT
        init_kwargs: Dict[str, Any] = {
            "verbose": verbose,
            "top_n": top_n,
            "rankgpt_rerank_prompt": rankgpt_rerank_prompt,
        }
        # Only forward ``llm`` when the caller actually supplied one. An explicit
        # ``None`` would bypass the field's ``default_factory``, and since the
        # field is annotated ``LLM`` (not ``Optional[LLM]``) pydantic rejects it.
        if llm is not None:
            init_kwargs["llm"] = llm
        super().__init__(**init_kwargs)

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name of this component."""
        return "RankGPTRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` with the LLM and return at most ``top_n`` of them.

        If the LLM reply has no message content, the incoming order is kept
        and simply truncated to ``top_n``.

        Raises:
            ValueError: if ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            # Nothing to rank: skip the LLM round-trip (the result would be [] anyway).
            return []

        items = {
            "query": query_bundle.query_str,
            "hits": [{"content": node.get_content()} for node in nodes],
        }

        messages = self.create_permutation_instruction(item=items)
        permutation = self.run_llm(messages=messages)
        if permutation.message is None or permutation.message.content is None:
            # No usable answer from the LLM: fall back to the original order.
            return nodes[: self.top_n]

        rerank_ranks = self._receive_permutation(
            items, str(permutation.message.content)
        )
        if self.verbose:
            print_text(f"After Reranking, new rank list for nodes: {rerank_ranks}")

        # ``rerank_ranks`` is a permutation of range(len(nodes)), so every index is valid.
        return [
            NodeWithScore(node=nodes[idx].node, score=nodes[idx].score)
            for idx in rerank_ranks[: self.top_n]
        ]

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"rankgpt_rerank_prompt": self.rankgpt_rerank_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "rankgpt_rerank_prompt" in prompts:
            self.rankgpt_rerank_prompt = prompts["rankgpt_rerank_prompt"]

    def _get_prefix_prompt(self, query: str, num: int) -> List[ChatMessage]:
        """Return the opening turns: system role, task statement, assistant acknowledgement."""
        return [
            ChatMessage(
                role="system",
                content="You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
            ),
            ChatMessage(
                role="user",
                content=f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
            ),
            ChatMessage(role="assistant", content="Okay, please provide the passages."),
        ]

    def _get_post_prompt(self, query: str, num: int) -> str:
        """Render the closing instruction from ``rankgpt_rerank_prompt``.

        The template is formatted with ``query`` and ``num``; its exact wording
        is defined by the configured prompt.
        """
        return self.rankgpt_rerank_prompt.format(query=query, num=num)

    def create_permutation_instruction(self, item: Dict[str, Any]) -> List[ChatMessage]:
        """Build the full chat transcript asking the LLM to rank ``item["hits"]``.

        Args:
            item: Dict with ``"query"`` (str) and ``"hits"`` (list of dicts each
                carrying a ``"content"`` string).

        Returns:
            Prefix turns, then one numbered user turn plus an assistant
            acknowledgement per passage, then the final ranking instruction.
        """
        query = item["query"]
        num = len(item["hits"])

        messages = self._get_prefix_prompt(query, num)
        for rank, hit in enumerate(item["hits"], start=1):
            content = hit["content"].replace("Title: Content: ", "").strip()
            # Keep only the first 300 whitespace-separated words to bound prompt size.
            # For Japanese (no spaces) this should cut by character instead.
            content = " ".join(content.split()[:300])
            messages.append(ChatMessage(role="user", content=f"[{rank}] {content}"))
            messages.append(
                ChatMessage(role="assistant", content=f"Received passage [{rank}].")
            )
        messages.append(
            ChatMessage(role="user", content=self._get_post_prompt(query, num))
        )
        return messages

    def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
        """Send ``messages`` to the configured LLM's chat endpoint."""
        return self.llm.chat(messages)

    def _clean_response(self, response: str) -> str:
        """Replace every non-decimal-digit character with a space and strip.

        ``str.isdecimal`` is used rather than ``str.isdigit``: the latter also
        accepts characters such as superscript digits, which ``int()`` rejects
        and which would make ``_receive_permutation`` raise ``ValueError``.
        """
        return "".join(c if c.isdecimal() else " " for c in response).strip()

    def _remove_duplicate(self, response: List[int]) -> List[int]:
        """Drop repeated values, keeping the first occurrence of each, in order."""
        return list(dict.fromkeys(response))

    def _receive_permutation(self, item: Dict[str, Any], permutation: str) -> List[int]:
        """Turn the LLM reply into a complete permutation of 0-based hit indices.

        Numbers in the reply are read as 1-based passage ids. Duplicates and
        out-of-range ids are discarded, and any passages the LLM omitted are
        appended in their original order.
        """
        rank_end = len(item["hits"])

        response = self._clean_response(permutation)
        response_list = [int(x) - 1 for x in response.split()]
        response_list = self._remove_duplicate(response_list)
        response_list = [ss for ss in response_list if 0 <= ss < rank_end]
        ranked = set(response_list)
        # Add the rest of the rank.
        return response_list + [tt for tt in range(rank_end) if tt not in ranked]

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.