llama-index

Форк
0
83 строки · 2.7 Кб
1
from typing import Any, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CBEventType, EventPayload
5
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
6
from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle
7

8

9
class FlagEmbeddingReranker(BaseNodePostprocessor):
10
    """Flag Embedding Reranker."""
11

12
    model: str = Field(description="BAAI Reranker model name.")
13
    top_n: int = Field(description="Number of nodes to return sorted by score.")
14
    use_fp16: bool = Field(description="Whether to use fp16 for inference.")
15
    _model: Any = PrivateAttr()
16

17
    def __init__(
18
        self,
19
        top_n: int = 2,
20
        model: str = "BAAI/bge-reranker-large",
21
        use_fp16: bool = False,
22
    ) -> None:
23
        try:
24
            from FlagEmbedding import FlagReranker
25
        except ImportError:
26
            raise ImportError(
27
                "Cannot import FlagReranker package, please install it: ",
28
                "pip install git+https://github.com/FlagOpen/FlagEmbedding.git",
29
            )
30
        self._model = FlagReranker(
31
            model,
32
            use_fp16=use_fp16,
33
        )
34
        super().__init__(top_n=top_n, model=model, use_fp16=use_fp16)
35

36
    @classmethod
37
    def class_name(cls) -> str:
38
        return "FlagEmbeddingReranker"
39

40
    def _postprocess_nodes(
41
        self,
42
        nodes: List[NodeWithScore],
43
        query_bundle: Optional[QueryBundle] = None,
44
    ) -> List[NodeWithScore]:
45
        if query_bundle is None:
46
            raise ValueError("Missing query bundle in extra info.")
47
        if len(nodes) == 0:
48
            return []
49

50
        query_and_nodes = [
51
            (
52
                query_bundle.query_str,
53
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
54
            )
55
            for node in nodes
56
        ]
57

58
        with self.callback_manager.event(
59
            CBEventType.RERANKING,
60
            payload={
61
                EventPayload.NODES: nodes,
62
                EventPayload.MODEL_NAME: self.model,
63
                EventPayload.QUERY_STR: query_bundle.query_str,
64
                EventPayload.TOP_K: self.top_n,
65
            },
66
        ) as event:
67
            scores = self._model.compute_score(query_and_nodes)
68

69
            # a single node passed into compute_score returns a float
70
            if isinstance(scores, float):
71
                scores = [scores]
72

73
            assert len(scores) == len(nodes)
74

75
            for node, score in zip(nodes, scores):
76
                node.score = score
77

78
            new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
79
                : self.top_n
80
            ]
81
            event.on_end(payload={EventPayload.NODES: new_nodes})
82

83
        return new_nodes
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.