llama-index
112 lines · 4.1 KB
1"""LLM reranker."""
2
3from typing import Callable, List, Optional4
5from llama_index.legacy.bridge.pydantic import Field, PrivateAttr6from llama_index.legacy.indices.utils import (7default_format_node_batch_fn,8default_parse_choice_select_answer_fn,9)
10from llama_index.legacy.postprocessor.types import BaseNodePostprocessor11from llama_index.legacy.prompts import BasePromptTemplate12from llama_index.legacy.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT13from llama_index.legacy.prompts.mixin import PromptDictType14from llama_index.legacy.schema import NodeWithScore, QueryBundle15from llama_index.legacy.service_context import ServiceContext16
17
class LLMRerank(BaseNodePostprocessor):
    """LLM-based reranker.

    Scores candidate nodes against the query by prompting an LLM with a
    choice-select prompt, one batch at a time, then returns the top-n
    nodes ordered by the LLM-assigned relevance.
    """

    # Pydantic-managed configuration fields.
    top_n: int = Field(description="Top N nodes to return.")
    choice_select_prompt: BasePromptTemplate = Field(
        description="Choice select prompt."
    )
    choice_batch_size: int = Field(description="Batch size for choice select.")
    service_context: ServiceContext = Field(
        description="Service context.", exclude=True
    )

    # Pluggable hooks: batch formatter and LLM-answer parser.
    _format_node_batch_fn: Callable = PrivateAttr()
    _parse_choice_select_answer_fn: Callable = PrivateAttr()

    def __init__(
        self,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        top_n: int = 10,
    ) -> None:
        """Create the reranker, filling unset arguments with defaults."""
        # Private attrs are assigned before delegating to the pydantic
        # constructor, mirroring how the base class expects them set.
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )

        super().__init__(
            choice_select_prompt=choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT,
            choice_batch_size=choice_batch_size,
            service_context=service_context or ServiceContext.from_defaults(),
            top_n=top_n,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"choice_select_prompt": self.choice_select_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        new_prompt = prompts.get("choice_select_prompt")
        if "choice_select_prompt" in prompts:
            self.choice_select_prompt = new_prompt

    @classmethod
    def class_name(cls) -> str:
        return "LLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` via per-batch LLM choice selection."""
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            return []

        query_str = query_bundle.query_str
        batch_size = self.choice_batch_size
        reranked: List[NodeWithScore] = []

        for start in range(0, len(nodes), batch_size):
            batch = [scored.node for scored in nodes[start : start + batch_size]]

            # Each batch is judged independently with its own LLM call.
            fmt_batch_str = self._format_node_batch_fn(batch)
            raw_response = self.service_context.llm.predict(
                self.choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            choices, scores = self._parse_choice_select_answer_fn(
                raw_response, len(batch)
            )
            # Parsed choices are 1-based positions within the batch.
            picked = [batch[int(choice) - 1] for choice in choices]
            if not scores:
                # No relevances parsed: treat every pick as equally relevant.
                scores = [1.0] * len(picked)
            reranked.extend(
                NodeWithScore(node=node, score=score)
                for node, score in zip(picked, scores)
            )

        # Highest relevance first; missing scores sort as 0.0.
        reranked.sort(key=lambda nws: nws.score or 0.0, reverse=True)
        return reranked[: self.top_n]