llama-index

Форк
0
96 строк · 3.3 Кб
1
from typing import Any, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CBEventType, EventPayload
5
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
6
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle
7
from llama_index.legacy.utils import infer_torch_device
8

9
# Passed as `max_length` when SentenceTransformerRerank constructs its CrossEncoder.
DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH = 512
10

11

12
class SentenceTransformerRerank(BaseNodePostprocessor):
    """Rerank nodes with a sentence-transformers ``CrossEncoder``.

    Each node is scored against the query string by the cross-encoder; the
    nodes are then sorted by that score (highest first) and truncated to
    ``top_n`` entries.
    """

    model: str = Field(description="Sentence transformer model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    # The loaded CrossEncoder instance; created in __init__.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ):
        """Load the cross-encoder and initialise the postprocessor fields.

        Args:
            top_n: Number of nodes to keep after reranking.
            model: Name of the cross-encoder model to load.
            device: Device to run the model on. When ``None`` it is chosen
                by ``infer_torch_device()``.
            keep_retrieval_score: When true, the pre-rerank score of each
                node is stored in its metadata under ``"retrieval_score"``.

        Raises:
            ImportError: If ``sentence_transformers`` cannot be imported.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as exc:
            # One single message string: the two literals are concatenated
            # (previously a stray comma turned them into two exception args,
            # so the message was rendered as a tuple).
            raise ImportError(
                "Cannot import sentence-transformers or torch package, "
                "please `pip install torch sentence-transformers`"
            ) from exc
        device = infer_torch_device() if device is None else device
        self._model = CrossEncoder(
            model, max_length=DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH, device=device
        )
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "SentenceTransformerRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Score ``nodes`` against the query and return the ``top_n`` best.

        The ``score`` of every input node is overwritten in place with the
        cross-encoder score. When ``keep_retrieval_score`` is set, the
        previous score is first saved to ``node.node.metadata``.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is paired with each node.

        Returns:
            At most ``top_n`` nodes, ordered by descending rerank score.
            An empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        # One (query, node text) pair per node, in the same order as `nodes`.
        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # keep the retrieval score in metadata
                    node.node.metadata["retrieval_score"] = node.score
                # predict() yields NumPy scalars by default; store a plain
                # Python float so the score matches the field's float type
                # and stays serializable.
                node.score = float(score)

            # Negated score gives descending order; a missing/zero score
            # sorts with key 0.
            new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
97

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.