llama-index
76 строк · 2.9 Кб
1from typing import Any, Callable, Optional, Sequence2
3from llama_index.legacy.core.embeddings.base import SimilarityMode, similarity4from llama_index.legacy.evaluation.base import BaseEvaluator, EvaluationResult5from llama_index.legacy.prompts.mixin import PromptDictType6from llama_index.legacy.service_context import ServiceContext7
8
class SemanticSimilarityEvaluator(BaseEvaluator):
    """Embedding similarity evaluator.

    Evaluate the quality of a question answering system by comparing the
    similarity between embeddings of the generated answer and the reference
    answer.

    Inspired by this paper:
    - Semantic Answer Similarity for Evaluating Question Answering Models
      https://arxiv.org/pdf/2108.06130.pdf

    Args:
        service_context (Optional[ServiceContext]): Service context whose
            ``embed_model`` is used to embed both texts. Defaults to
            ``ServiceContext.from_defaults()``.
        similarity_fn (Optional[Callable[..., float]]): Custom scorer, called
            as ``similarity_fn(response_embedding, reference_embedding)``.
            Mutually exclusive with ``similarity_mode``.
        similarity_mode (Optional[SimilarityMode]): Mode forwarded to the
            ``similarity`` helper when no ``similarity_fn`` is supplied.
            Defaults to ``SimilarityMode.DEFAULT``.
        similarity_threshold (float): Minimum score (inclusive) for a result
            to count as "passing". Defaults to 0.8.

    Raises:
        ValueError: If both ``similarity_fn`` and ``similarity_mode`` are
            supplied.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        similarity_fn: Optional[Callable[..., float]] = None,
        similarity_mode: Optional[SimilarityMode] = None,
        similarity_threshold: float = 0.8,
    ) -> None:
        self._service_context = service_context or ServiceContext.from_defaults()

        if similarity_fn is not None:
            # A custom scorer replaces the mode-based one entirely, so a mode
            # given alongside it would be meaningless; reject the combination.
            if similarity_mode is not None:
                raise ValueError(
                    "Cannot specify both similarity_fn and similarity_mode"
                )
            self._similarity_fn = similarity_fn
        else:
            mode = similarity_mode or SimilarityMode.DEFAULT

            def _mode_similarity(x: Any, y: Any) -> float:
                # Delegates to the shared helper with the resolved mode.
                return similarity(x, y, mode=mode)

            self._similarity_fn = _mode_similarity

        self._similarity_threshold = similarity_threshold

    def _get_prompts(self) -> PromptDictType:
        """Get prompts.

        This evaluator scores via embeddings only, so it exposes no prompts
        and always returns an empty mapping.
        """
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts.

        No-op: there are no prompts to update.
        """

    async def aevaluate(
        self,
        query: Optional[str] = None,
        response: Optional[str] = None,
        contexts: Optional[Sequence[str]] = None,
        reference: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        """Score ``response`` against ``reference`` by embedding similarity.

        ``query``, ``contexts`` and any extra keyword arguments are accepted
        for interface compatibility and ignored.

        Returns:
            EvaluationResult: ``score`` is the similarity value, ``passing``
            is ``score >= similarity_threshold``, and ``feedback`` reports
            the score as text.

        Raises:
            ValueError: If ``response`` or ``reference`` is ``None``.
        """
        del query, contexts, kwargs  # Unused

        missing_input = response is None or reference is None
        if missing_input:
            raise ValueError("Must specify both response and reference")

        model = self._service_context.embed_model
        # Embeddings are awaited one after the other: response, then reference.
        response_vec = await model.aget_text_embedding(response)
        reference_vec = await model.aget_text_embedding(reference)

        score = self._similarity_fn(response_vec, reference_vec)
        return EvaluationResult(
            score=score,
            passing=score >= self._similarity_threshold,
            feedback=f"Similarity score: {score}",
        )