llama-index
78 строк · 2.5 Кб
1import os2from typing import Any, List, Optional3
4from llama_index.legacy.bridge.pydantic import Field, PrivateAttr5from llama_index.legacy.callbacks import CBEventType, EventPayload6from llama_index.legacy.postprocessor.types import BaseNodePostprocessor7from llama_index.legacy.schema import NodeWithScore, QueryBundle8
9
class CohereRerank(BaseNodePostprocessor):
    """Node postprocessor that reorders nodes with the Cohere rerank API.

    The content of every input node is sent to ``Client.rerank`` together with
    the query string. The results Cohere returns (requested with ``top_n``)
    become the new node list, and each node's score is replaced by Cohere's
    ``relevance_score``.
    """

    model: str = Field(description="Cohere model name.")
    top_n: int = Field(description="Top N nodes to return.")

    # cohere ``Client`` instance, created in ``__init__``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "rerank-english-v2.0",
        api_key: Optional[str] = None,
    ):
        """Create the reranker.

        Args:
            top_n: Value passed as ``top_n`` to ``Client.rerank``.
            model: Cohere rerank model name.
            api_key: Cohere API key. When not given (or empty), the
                ``COHERE_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no API key is passed and ``COHERE_API_KEY`` is
                unset or empty.
            ImportError: If the ``cohere`` package cannot be imported.
        """
        # Reading os.environ with [] raises KeyError (not IndexError) for a
        # missing variable; using .get() plus an explicit check also rejects
        # an empty key instead of handing "" to the client.
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable "
            )
        try:
            from cohere import Client
        except ImportError as err:
            raise ImportError(
                "Cannot import cohere package, please `pip install cohere`."
            ) from err

        # NOTE(review): the private attribute is assigned before
        # super().__init__(); this relies on the bridged pydantic version
        # permitting private-attribute assignment pre-init — confirm on upgrade.
        self._client = Client(api_key=api_key)
        super().__init__(top_n=top_n, model=model)

    @classmethod
    def class_name(cls) -> str:
        """Return the registered name of this postprocessor class."""
        return "CohereRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against ``query_bundle.query_str`` using Cohere.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query to rank against; required.

        Returns:
            One ``NodeWithScore`` per result returned by Cohere, in the order
            Cohere returned them, scored with ``relevance_score``. An empty
            list when ``nodes`` is empty (no API call is made).

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            results = self._client.rerank(
                model=self.model,
                top_n=self.top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

            # ``result.index`` is a position in ``texts``, which lines up
            # one-to-one with ``nodes``, so it maps each result back to its node.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                for result in results
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes