llama-index

Форк
0
134 строки · 4.5 Кб
1
"""Retrieval evaluators."""
2

3
from typing import Any, List, Optional, Sequence, Tuple
4

5
from llama_index.legacy.bridge.pydantic import Field
6
from llama_index.legacy.core.base_retriever import BaseRetriever
7
from llama_index.legacy.evaluation.retrieval.base import (
8
    BaseRetrievalEvaluator,
9
    RetrievalEvalMode,
10
)
11
from llama_index.legacy.evaluation.retrieval.metrics_base import (
12
    BaseRetrievalMetric,
13
)
14
from llama_index.legacy.indices.base_retriever import BaseRetriever
15
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
16
from llama_index.legacy.schema import ImageNode, TextNode
17

18

19
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Evaluate a retriever's results with a collection of retrieval metrics.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics computed for each query.
        retriever (BaseRetriever): The retriever under evaluation.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Optional
            post-processors, run in list order on the retrieved nodes before
            their ids and texts are extracted.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Forward the metrics, retriever and post-processors to the base model."""
        super().__init__(
            retriever=retriever,
            metrics=metrics,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Retrieve nodes for ``query`` and return their ids and texts.

        Every configured post-processor is applied, in order, to the retrieved
        nodes first. ``mode`` is not used by this evaluator: every remaining
        node is returned regardless of modality.

        Returns:
            A pair ``(ids, texts)`` of parallel lists, one entry per node.
        """
        results = await self.retriever.aretrieve(query)
        # A missing (None) or empty post-processor list leaves results untouched.
        for postprocessor in self.node_postprocessors or []:
            results = postprocessor.postprocess_nodes(results, query_str=query)

        # Unwrap each scored result to its underlying node once, then project.
        base_nodes = [result.node for result in results]
        node_ids = [base_node.node_id for base_node in base_nodes]
        node_texts = [base_node.text for base_node in base_nodes]
        return node_ids, node_texts
68

69

70
class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator):
    """Multi-modal retriever evaluator.

    Evaluates a retriever whose results may contain both text and image nodes.
    Depending on the evaluation mode, either the text nodes or the image nodes
    are handed to the metrics.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics to evaluate.
        retriever (BaseRetriever): Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Optional
            post-processors, applied in order to the retrieved nodes before
            they are split by modality.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts for the requested modality.

        This method name matches the hook ``RetrieverEvaluator`` implements in
        this module. The previous spelling (``_aget_retrieved_ids_texts``) did
        not match it, so this class never supplied that hook.

        Args:
            query: Query string passed to the retriever and post-processors.
            mode: ``RetrievalEvalMode.TEXT`` selects nodes with non-empty text;
                ``RetrievalEvalMode.IMAGE`` selects ``ImageNode`` instances.

        Returns:
            A pair ``(ids, texts)`` of parallel lists for the selected nodes.

        Raises:
            ValueError: If ``mode`` is neither text nor image.
        """
        retrieved_nodes = await self.retriever.aretrieve(query)

        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        image_nodes: List[ImageNode] = []
        text_nodes: List[TextNode] = []
        for scored_node in retrieved_nodes:
            node = scored_node.node
            if isinstance(node, ImageNode):
                image_nodes.append(node)
            # Deliberately not ``elif``: an image node that also carries text
            # is counted in both lists.
            if node.text:
                text_nodes.append(node)

        # Compared against the enum members; this matches plain "text"/"image"
        # strings too, assuming RetrievalEvalMode is a str-based Enum with those
        # values (as the original string comparisons imply) — confirm.
        if mode == RetrievalEvalMode.TEXT:
            selected_nodes = text_nodes
        elif mode == RetrievalEvalMode.IMAGE:
            selected_nodes = image_nodes
        else:
            raise ValueError("Unsupported mode.")

        return (
            [node.node_id for node in selected_nodes],
            [node.text for node in selected_nodes],
        )

    # Backward-compatible alias for the previous, misspelled method name.
    _aget_retrieved_ids_texts = _aget_retrieved_ids_and_texts
135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.