llama-index
158 строк · 5.9 Кб
1import logging2from typing import Any, Dict, List, Optional, Sequence3
4from llama_index.legacy.bridge.pydantic import Field5from llama_index.legacy.llms import LLM, ChatMessage, ChatResponse, OpenAI6from llama_index.legacy.postprocessor.types import BaseNodePostprocessor7from llama_index.legacy.prompts import BasePromptTemplate8from llama_index.legacy.prompts.default_prompts import RANKGPT_RERANK_PROMPT9from llama_index.legacy.prompts.mixin import PromptDictType10from llama_index.legacy.schema import NodeWithScore, QueryBundle11from llama_index.legacy.utils import print_text12
# Module-level logger for this postprocessor. Its threshold is pinned to
# WARNING so that lower-severity records from this module are dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
16
class RankGPTRerank(BaseNodePostprocessor):
    """RankGPT-based reranker.

    The query and the candidate passages are sent to a chat LLM as a
    multi-turn conversation. The LLM's reply is parsed into a permutation of
    passage indices, and the first ``top_n`` nodes in that order are returned
    with their original scores.
    """

    top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
    llm: LLM = Field(
        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
        description="LLM to use for rankGPT",
    )
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    rankgpt_rerank_prompt: BasePromptTemplate = Field(
        description="rankGPT rerank prompt."
    )

    def __init__(
        self,
        top_n: int = 5,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        rankgpt_rerank_prompt: Optional[BasePromptTemplate] = None,
    ):
        """Create the reranker.

        Args:
            top_n: Number of nodes to keep after reranking.
            llm: Chat LLM used for ranking. When omitted, the ``llm`` field's
                ``default_factory`` supplies the default model.
            verbose: Whether to print the computed rank list.
            rankgpt_rerank_prompt: Final instruction prompt. Falls back to
                ``RANKGPT_RERANK_PROMPT`` when not given.
        """
        init_kwargs: Dict[str, Any] = {
            "verbose": verbose,
            "top_n": top_n,
            "rankgpt_rerank_prompt": rankgpt_rerank_prompt or RANKGPT_RERANK_PROMPT,
        }
        # Only forward ``llm`` when the caller supplied one. Passing an
        # explicit ``None`` would bypass the field's ``default_factory`` and
        # hand ``None`` to a field that is not declared Optional.
        if llm is not None:
            init_kwargs["llm"] = llm
        super().__init__(**init_kwargs)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "RankGPTRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and keep the first ``top_n``.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query to rank against; required.

        Returns:
            Up to ``top_n`` nodes. They are in LLM-ranked order when the LLM
            returned content, otherwise in their incoming order.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        # Nothing to rank: skip the LLM round-trip. The result is the same
        # empty list the full path would produce.
        if not nodes:
            return []

        items = {
            "query": query_bundle.query_str,
            "hits": [{"content": node.get_content()} for node in nodes],
        }

        messages = self.create_permutation_instruction(item=items)
        permutation = self.run_llm(messages=messages)

        # Without any reply content there is no ranking to apply; fall back
        # to the incoming order.
        if permutation.message is None or permutation.message.content is None:
            return nodes[: self.top_n]

        rerank_ranks = self._receive_permutation(
            items, str(permutation.message.content)
        )
        if self.verbose:
            print_text(f"After Reranking, new rank list for nodes: {rerank_ranks}")

        # Each node keeps its original score; only the order changes.
        return [
            NodeWithScore(node=nodes[idx].node, score=nodes[idx].score)
            for idx in rerank_ranks[: self.top_n]
        ]

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"rankgpt_rerank_prompt": self.rankgpt_rerank_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "rankgpt_rerank_prompt" in prompts:
            self.rankgpt_rerank_prompt = prompts["rankgpt_rerank_prompt"]

    def _get_prefix_prompt(self, query: str, num: int) -> List[ChatMessage]:
        """Build the opening turns announcing the task and passage count."""
        return [
            ChatMessage(
                role="system",
                content="You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
            ),
            ChatMessage(
                role="user",
                content=f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
            ),
            ChatMessage(role="assistant", content="Okay, please provide the passages."),
        ]

    def _get_post_prompt(self, query: str, num: int) -> str:
        """Format the final ranking instruction from the configured prompt."""
        return self.rankgpt_rerank_prompt.format(query=query, num=num)

    def create_permutation_instruction(self, item: Dict[str, Any]) -> List[ChatMessage]:
        """Build the full chat transcript sent to the LLM.

        Args:
            item: Dict with ``"query"`` (the query string) and ``"hits"``
                (a list of dicts, each with a ``"content"`` string).

        Returns:
            The prefix turns, one user/assistant pair per passage (numbered
            from 1), and the final ranking instruction.
        """
        query = item["query"]
        hits = item["hits"]
        num = len(hits)

        messages = self._get_prefix_prompt(query, num)
        for rank, hit in enumerate(hits, start=1):
            content = hit["content"].replace("Title: Content: ", "").strip()
            # Keep at most the first 300 whitespace-separated words.
            # NOTE: text without spaces (e.g. Japanese) is not shortened by
            # this; such text would need to be cut by character instead.
            content = " ".join(content.split()[:300])
            messages.append(ChatMessage(role="user", content=f"[{rank}] {content}"))
            messages.append(
                ChatMessage(role="assistant", content=f"Received passage [{rank}].")
            )
        messages.append(
            ChatMessage(role="user", content=self._get_post_prompt(query, num))
        )
        return messages

    def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
        """Send ``messages`` to the configured LLM and return its response."""
        return self.llm.chat(messages)

    def _clean_response(self, response: str) -> str:
        """Replace every non-decimal character with a space and strip the ends.

        ``str.isdecimal`` is used rather than ``str.isdigit`` because the
        result is later fed to ``int()``. ``isdigit`` also accepts characters
        such as superscripts or circled digits, which ``int()`` rejects.
        """
        return "".join(ch if ch.isdecimal() else " " for ch in response).strip()

    def _remove_duplicate(self, response: List[int]) -> List[int]:
        """Drop repeated values, keeping the first occurrence of each."""
        # dict keys preserve insertion order, giving an order-stable dedupe.
        return list(dict.fromkeys(response))

    def _receive_permutation(self, item: Dict[str, Any], permutation: str) -> List[int]:
        """Parse the LLM reply into a complete ordering of hit indices.

        Args:
            item: Dict whose ``"hits"`` list defines how many passages exist.
            permutation: Raw text of the LLM reply.

        Returns:
            Zero-based indices covering every hit exactly once. Indices named
            in the reply come first, in reply order, with duplicates and
            out-of-range values dropped. All remaining indices follow in
            ascending order.
        """
        rank_end = len(item["hits"])

        cleaned = self._clean_response(permutation)
        # The reply numbers passages from 1; convert to zero-based indices.
        candidates = self._remove_duplicate([int(tok) - 1 for tok in cleaned.split()])
        ranked = [idx for idx in candidates if 0 <= idx < rank_end]

        # Append whatever the LLM failed to mention so no hit is lost.
        mentioned = set(ranked)
        return ranked + [idx for idx in range(rank_end) if idx not in mentioned]