"""Base retrieval abstractions."""

import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.evaluation.retrieval.metrics import resolve_metrics
from llama_index.legacy.evaluation.retrieval.metrics_base import (
    BaseRetrievalMetric,
    RetrievalMetricResult,
)
from llama_index.legacy.finetuning.embeddings.common import EmbeddingQAFinetuneDataset


class RetrievalEvalMode(str, Enum):
    """Evaluation of retrieval modality."""

    TEXT = "text"
    IMAGE = "image"

    @classmethod
    def from_str(cls, label: str) -> "RetrievalEvalMode":
        if label == "text":
            return RetrievalEvalMode.TEXT
        elif label == "image":
            return RetrievalEvalMode.IMAGE
        else:
            raise NotImplementedError
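
    # Illustrative usage (comment sketch, not in the original source):
    #
    #     RetrievalEvalMode.from_str("text")   # -> RetrievalEvalMode.TEXT
    #     RetrievalEvalMode.from_str("audio")  # raises NotImplementedError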


class RetrievalEvalResult(BaseModel):
    """Retrieval eval result.

    NOTE: this abstraction might change in the future.

    Attributes:
        query (str): Query string
        expected_ids (List[str]): Expected ids
        expected_texts (Optional[List[str]]): Expected texts associated with \
            nodes provided in `expected_ids`
        retrieved_ids (List[str]): Retrieved ids
        retrieved_texts (List[str]): Retrieved texts
        mode (RetrievalEvalMode): Evaluation modality, text or image
        metric_dict (Dict[str, RetrievalMetricResult]): \
            Metric dictionary for the evaluation

    """

    class Config:
        arbitrary_types_allowed = True

    query: str = Field(..., description="Query string")
    expected_ids: List[str] = Field(..., description="Expected ids")
    expected_texts: Optional[List[str]] = Field(
        default=None,
        description="Expected texts associated with nodes provided in `expected_ids`",
    )
    retrieved_ids: List[str] = Field(..., description="Retrieved ids")
    retrieved_texts: List[str] = Field(..., description="Retrieved texts")
    mode: "RetrievalEvalMode" = Field(
        default=RetrievalEvalMode.TEXT, description="text or image"
    )
    metric_dict: Dict[str, RetrievalMetricResult] = Field(
        ..., description="Metric dictionary for the evaluation"
    )

    @property
    def metric_vals_dict(self) -> Dict[str, float]:
        """Dictionary of metric values."""
        return {k: v.score for k, v in self.metric_dict.items()}

    def __str__(self) -> str:
        """String representation."""
        return f"Query: {self.query}\nMetrics: {self.metric_vals_dict!s}\n"


class BaseRetrievalEvaluator(BaseModel):
    """Base Retrieval Evaluator class."""

    metrics: List[BaseRetrievalMetric] = Field(
        ..., description="List of metrics to evaluate"
    )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """Create evaluator from metric names.

        Args:
            metric_names (List[str]): List of metric names
            **kwargs: Additional arguments for the evaluator

        """
        metric_types = resolve_metrics(metric_names)
        return cls(metrics=[metric() for metric in metric_types], **kwargs)
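
    # Illustrative usage (a sketch; `RetrieverEvaluator` is a concrete subclass
    # defined elsewhere in llama_index, and `retriever` is assumed to come from
    # an existing index):
    #
    #     evaluator = RetrieverEvaluator.from_metric_names(
    #         ["mrr", "hit_rate"], retriever=retriever
    #     )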

    @abstractmethod
    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts."""
        raise NotImplementedError
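
    # A minimal sketch (hypothetical, not in the original source) of how a
    # retriever-backed subclass might implement this hook, assuming retrieved
    # nodes expose `node_id` and `get_content()`:
    #
    #     async def _aget_retrieved_ids_and_texts(
    #         self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    #     ) -> Tuple[List[str], List[str]]:
    #         nodes = await self.retriever.aretrieve(query)
    #         return (
    #             [node.node_id for node in nodes],
    #             [node.get_content() for node in nodes],
    #         )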

    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with query string and expected ids.

        Args:
            query (str): Query string
            expected_ids (List[str]): Expected ids

        Returns:
            RetrievalEvalResult: Evaluation result

        """
        return asyncio.run(
            self.aevaluate(
                query=query,
                expected_ids=expected_ids,
                expected_texts=expected_texts,
                mode=mode,
                **kwargs,
            )
        )
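
    # Illustrative usage (sketch; `evaluator` is any concrete subclass). This
    # drives the async path via asyncio.run, so call it from synchronous code:
    #
    #     eval_result = evaluator.evaluate(
    #         query="What is a vector index?",
    #         expected_ids=["node_1", "node_2"],
    #     )
    #     print(eval_result.metric_vals_dict)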

    # @abstractmethod
    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run async evaluation with query string and expected ids.

        Subclasses can override this method to provide custom evaluation logic and
        take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        metric_dict = {}
        for metric in self.metrics:
            eval_result = metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            metric_dict[metric.metric_name] = eval_result

        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )

    async def aevaluate_dataset(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        workers: int = 2,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[RetrievalEvalResult]:
        """Run evaluation with dataset."""
        semaphore = asyncio.Semaphore(workers)

        async def eval_worker(
            query: str, expected_ids: List[str], mode: RetrievalEvalMode
        ) -> RetrievalEvalResult:
            # Bound concurrency: at most `workers` evaluations run at once.
            async with semaphore:
                return await self.aevaluate(query, expected_ids=expected_ids, mode=mode)

        response_jobs = []
        mode = RetrievalEvalMode.from_str(dataset.mode)
        for query_id, query in dataset.queries.items():
            expected_ids = dataset.relevant_docs[query_id]
            response_jobs.append(eval_worker(query, expected_ids, mode))
        if show_progress:
            from tqdm.asyncio import tqdm_asyncio

            eval_results = await tqdm_asyncio.gather(*response_jobs)
        else:
            eval_results = await asyncio.gather(*response_jobs)

        return eval_results
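
    # Illustrative usage (a sketch; assumes a QA dataset previously saved with
    # `EmbeddingQAFinetuneDataset` and an `evaluator` built via
    # `from_metric_names`):
    #
    #     dataset = EmbeddingQAFinetuneDataset.from_json("qa_dataset.json")
    #     eval_results = asyncio.run(
    #         evaluator.aevaluate_dataset(dataset, workers=4, show_progress=True)
    #     )
    #     for r in eval_results:
    #         print(r)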