llama-index

Форк
0
144 строки · 4.4 Кб
1
import os
2
from typing import Any, Callable, Dict, List, Literal, Optional, Type
3

4
import numpy as np
5

6
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
7
from llama_index.legacy.evaluation.retrieval.metrics_base import (
8
    BaseRetrievalMetric,
9
    RetrievalMetricResult,
10
)
11

12
_AGG_FUNC: Dict[str, Callable] = {"mean": np.mean, "median": np.median, "max": np.max}
13

14

15
class HitRate(BaseRetrievalMetric):
16
    """Hit rate metric."""
17

18
    metric_name: str = "hit_rate"
19

20
    def compute(
21
        self,
22
        query: Optional[str] = None,
23
        expected_ids: Optional[List[str]] = None,
24
        retrieved_ids: Optional[List[str]] = None,
25
        expected_texts: Optional[List[str]] = None,
26
        retrieved_texts: Optional[List[str]] = None,
27
        **kwargs: Any,
28
    ) -> RetrievalMetricResult:
29
        """Compute metric."""
30
        if retrieved_ids is None or expected_ids is None:
31
            raise ValueError("Retrieved ids and expected ids must be provided")
32
        is_hit = any(id in expected_ids for id in retrieved_ids)
33
        return RetrievalMetricResult(
34
            score=1.0 if is_hit else 0.0,
35
        )
36

37

38
class MRR(BaseRetrievalMetric):
39
    """MRR metric."""
40

41
    metric_name: str = "mrr"
42

43
    def compute(
44
        self,
45
        query: Optional[str] = None,
46
        expected_ids: Optional[List[str]] = None,
47
        retrieved_ids: Optional[List[str]] = None,
48
        expected_texts: Optional[List[str]] = None,
49
        retrieved_texts: Optional[List[str]] = None,
50
        **kwargs: Any,
51
    ) -> RetrievalMetricResult:
52
        """Compute metric."""
53
        if retrieved_ids is None or expected_ids is None:
54
            raise ValueError("Retrieved ids and expected ids must be provided")
55
        for i, id in enumerate(retrieved_ids):
56
            if id in expected_ids:
57
                return RetrievalMetricResult(
58
                    score=1.0 / (i + 1),
59
                )
60
        return RetrievalMetricResult(
61
            score=0.0,
62
        )
63

64

65
class CohereRerankRelevancyMetric(BaseRetrievalMetric):
66
    """Cohere rerank relevancy metric."""
67

68
    model: str = Field(description="Cohere model name.")
69
    metric_name: str = "cohere_rerank_relevancy"
70

71
    _client: Any = PrivateAttr()
72

73
    def __init__(
74
        self,
75
        model: str = "rerank-english-v2.0",
76
        api_key: Optional[str] = None,
77
    ):
78
        try:
79
            api_key = api_key or os.environ["COHERE_API_KEY"]
80
        except IndexError:
81
            raise ValueError(
82
                "Must pass in cohere api key or "
83
                "specify via COHERE_API_KEY environment variable "
84
            )
85
        try:
86
            from cohere import Client
87
        except ImportError:
88
            raise ImportError(
89
                "Cannot import cohere package, please `pip install cohere`."
90
            )
91

92
        self._client = Client(api_key=api_key)
93
        super().__init__(model=model)
94

95
    def _get_agg_func(self, agg: Literal["max", "median", "mean"]) -> Callable:
96
        """Get agg func."""
97
        return _AGG_FUNC[agg]
98

99
    def compute(
100
        self,
101
        query: Optional[str] = None,
102
        expected_ids: Optional[List[str]] = None,
103
        retrieved_ids: Optional[List[str]] = None,
104
        expected_texts: Optional[List[str]] = None,
105
        retrieved_texts: Optional[List[str]] = None,
106
        agg: Literal["max", "median", "mean"] = "max",
107
        **kwargs: Any,
108
    ) -> RetrievalMetricResult:
109
        """Compute metric."""
110
        del expected_texts  # unused
111

112
        if retrieved_texts is None:
113
            raise ValueError("Retrieved texts must be provided")
114

115
        results = self._client.rerank(
116
            model=self.model,
117
            top_n=len(
118
                retrieved_texts
119
            ),  # i.e. get a rank score for each retrieved chunk
120
            query=query,
121
            documents=retrieved_texts,
122
        )
123
        relevance_scores = [r.relevance_score for r in results]
124
        agg_func = self._get_agg_func(agg)
125

126
        return RetrievalMetricResult(
127
            score=agg_func(relevance_scores), metadata={"agg": agg}
128
        )
129

130

131
METRIC_REGISTRY: Dict[str, Type[BaseRetrievalMetric]] = {
132
    "hit_rate": HitRate,
133
    "mrr": MRR,
134
    "cohere_rerank_relevancy": CohereRerankRelevancyMetric,
135
}
136

137

138
def resolve_metrics(metrics: List[str]) -> List[Type[BaseRetrievalMetric]]:
139
    """Resolve metrics from list of metric names."""
140
    for metric in metrics:
141
        if metric not in METRIC_REGISTRY:
142
            raise ValueError(f"Invalid metric name: {metric}")
143

144
    return [METRIC_REGISTRY[metric] for metric in metrics]
145

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.