llama-index

Форк
0
78 строк · 2.5 Кб
1
import os
2
from typing import Any, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.callbacks import CBEventType, EventPayload
6
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
7
from llama_index.legacy.schema import NodeWithScore, QueryBundle
8

9

10
class CohereRerank(BaseNodePostprocessor):
11
    model: str = Field(description="Cohere model name.")
12
    top_n: int = Field(description="Top N nodes to return.")
13

14
    _client: Any = PrivateAttr()
15

16
    def __init__(
17
        self,
18
        top_n: int = 2,
19
        model: str = "rerank-english-v2.0",
20
        api_key: Optional[str] = None,
21
    ):
22
        try:
23
            api_key = api_key or os.environ["COHERE_API_KEY"]
24
        except IndexError:
25
            raise ValueError(
26
                "Must pass in cohere api key or "
27
                "specify via COHERE_API_KEY environment variable "
28
            )
29
        try:
30
            from cohere import Client
31
        except ImportError:
32
            raise ImportError(
33
                "Cannot import cohere package, please `pip install cohere`."
34
            )
35

36
        self._client = Client(api_key=api_key)
37
        super().__init__(top_n=top_n, model=model)
38

39
    @classmethod
40
    def class_name(cls) -> str:
41
        return "CohereRerank"
42

43
    def _postprocess_nodes(
44
        self,
45
        nodes: List[NodeWithScore],
46
        query_bundle: Optional[QueryBundle] = None,
47
    ) -> List[NodeWithScore]:
48
        if query_bundle is None:
49
            raise ValueError("Missing query bundle in extra info.")
50
        if len(nodes) == 0:
51
            return []
52

53
        with self.callback_manager.event(
54
            CBEventType.RERANKING,
55
            payload={
56
                EventPayload.NODES: nodes,
57
                EventPayload.MODEL_NAME: self.model,
58
                EventPayload.QUERY_STR: query_bundle.query_str,
59
                EventPayload.TOP_K: self.top_n,
60
            },
61
        ) as event:
62
            texts = [node.node.get_content() for node in nodes]
63
            results = self._client.rerank(
64
                model=self.model,
65
                top_n=self.top_n,
66
                query=query_bundle.query_str,
67
                documents=texts,
68
            )
69

70
            new_nodes = []
71
            for result in results:
72
                new_node_with_score = NodeWithScore(
73
                    node=nodes[result.index].node, score=result.relevance_score
74
                )
75
                new_nodes.append(new_node_with_score)
76
            event.on_end(payload={EventPayload.NODES: new_nodes})
77

78
        return new_nodes
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.