llama-index

from typing import Any, Callable, Optional, Sequence

from llama_index.legacy.core.embeddings.base import SimilarityMode, similarity
from llama_index.legacy.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.service_context import ServiceContext


class SemanticSimilarityEvaluator(BaseEvaluator):
    """Embedding similarity evaluator.

    Evaluate the quality of a question answering system by
    comparing the similarity between embeddings of the generated answer
    and the reference answer.

    Inspired by this paper:
    - Semantic Answer Similarity for Evaluating Question Answering Models
        https://arxiv.org/pdf/2108.06130.pdf

    Args:
        service_context (Optional[ServiceContext]): Service context.
        similarity_fn (Optional[Callable[..., float]]): Custom function that
            maps a pair of embeddings to a similarity score. Mutually
            exclusive with similarity_mode.
        similarity_mode (Optional[SimilarityMode]): Built-in similarity mode
            used when no custom function is given. Defaults to
            SimilarityMode.DEFAULT (cosine similarity).
        similarity_threshold (float): Embedding similarity threshold for "passing".
            Defaults to 0.8.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        similarity_fn: Optional[Callable[..., float]] = None,
        similarity_mode: Optional[SimilarityMode] = None,
        similarity_threshold: float = 0.8,
    ) -> None:
        self._service_context = service_context or ServiceContext.from_defaults()
        if similarity_fn is None:
            # Fall back to the built-in similarity function with the chosen mode.
            similarity_mode = similarity_mode or SimilarityMode.DEFAULT
            self._similarity_fn = lambda x, y: similarity(x, y, mode=similarity_mode)
        else:
            # A custom function and a built-in mode are mutually exclusive.
            if similarity_mode is not None:
                raise ValueError(
                    "Cannot specify both similarity_fn and similarity_mode"
                )
            self._similarity_fn = similarity_fn

        self._similarity_threshold = similarity_threshold

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    async def aevaluate(
        self,
        query: Optional[str] = None,
        response: Optional[str] = None,
        contexts: Optional[Sequence[str]] = None,
        reference: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        del query, contexts, kwargs  # Unused

        if response is None or reference is None:
            raise ValueError("Must specify both response and reference")

        # Embed both the generated answer and the reference answer.
        embed_model = self._service_context.embed_model
        response_embedding = await embed_model.aget_text_embedding(response)
        reference_embedding = await embed_model.aget_text_embedding(reference)

        # Score the embedding pair and check against the passing threshold.
        similarity_score = self._similarity_fn(response_embedding, reference_embedding)
        passing = similarity_score >= self._similarity_threshold
        return EvaluationResult(
            score=similarity_score,
            passing=passing,
            feedback=f"Similarity score: {similarity_score}",
        )
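
A minimal usage sketch (not part of the file above), assuming the class is exported from llama_index.legacy.evaluation and that the default ServiceContext can build an embedding model (e.g. OpenAI credentials are configured):

import asyncio

from llama_index.legacy.evaluation import SemanticSimilarityEvaluator


async def main() -> None:
    # Default ServiceContext: an embedding model must be available
    # (e.g. OPENAI_API_KEY set for the default OpenAI embeddings).
    evaluator = SemanticSimilarityEvaluator(similarity_threshold=0.8)
    result = await evaluator.aevaluate(
        response="The sky looks blue because air scatters short wavelengths.",
        reference="Rayleigh scattering of sunlight makes the sky appear blue.",
    )
    print(result.score, result.passing)  # similarity score and threshold check


asyncio.run(main())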

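A custom scoring function can also be supplied. Per the constructor above, similarity_fn and similarity_mode are mutually exclusive; a sketch with a hypothetical dot-product scorer:

from typing import List

import numpy as np

from llama_index.legacy.evaluation import SemanticSimilarityEvaluator


def dot_product(x: List[float], y: List[float]) -> float:
    # Hypothetical scorer: raw dot product of the two embedding vectors.
    return float(np.dot(x, y))


# Leave similarity_mode unset when passing similarity_fn,
# otherwise __init__ raises ValueError.
evaluator = SemanticSimilarityEvaluator(similarity_fn=dot_product)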