llama-index

"""Guideline evaluation."""

import asyncio
import logging
from typing import Any, Optional, Sequence, Union, cast

from llama_index.legacy import ServiceContext
from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.legacy.output_parsers import PydanticOutputParser
from llama_index.legacy.prompts import BasePromptTemplate, PromptTemplate
from llama_index.legacy.prompts.mixin import PromptDictType

logger = logging.getLogger(__name__)


DEFAULT_GUIDELINES = (
    "The response should fully answer the query.\n"
    "The response should avoid being vague or ambiguous.\n"
    "The response should be specific and use statistics or numbers when possible.\n"
)

DEFAULT_EVAL_TEMPLATE = PromptTemplate(
    "Here is the original query:\n"
    "Query: {query}\n"
    "Critique the following response based on the guidelines below:\n"
    "Response: {response}\n"
    "Guidelines: {guidelines}\n"
    "Now please provide constructive criticism.\n"
)


class EvaluationData(BaseModel):
    passing: bool = Field(description="Whether the response passes the guidelines.")
    feedback: str = Field(
        description="The feedback for the response based on the guidelines."
    )


class GuidelineEvaluator(BaseEvaluator):
    """Guideline evaluator.

    Evaluates whether a query and response pair passes the given guidelines.

    This evaluator only considers the query string and the response string.

    Args:
        service_context (Optional[ServiceContext]):
            The service context to use for evaluation.
        guidelines (Optional[str]): User-added guidelines to use for evaluation.
            Defaults to None, which uses the default guidelines.
        eval_template (Optional[Union[str, BasePromptTemplate]]):
            The template to use for evaluation.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        guidelines: Optional[str] = None,
        eval_template: Optional[Union[str, BasePromptTemplate]] = None,
    ) -> None:
        self._service_context = service_context or ServiceContext.from_defaults()
        self._guidelines = guidelines or DEFAULT_GUIDELINES

        self._eval_template: BasePromptTemplate
        if isinstance(eval_template, str):
            self._eval_template = PromptTemplate(eval_template)
        else:
            self._eval_template = eval_template or DEFAULT_EVAL_TEMPLATE

        self._output_parser = PydanticOutputParser(output_cls=EvaluationData)
        self._eval_template.output_parser = self._output_parser

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "eval_template": self._eval_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "eval_template" in prompts:
            self._eval_template = prompts["eval_template"]

    async def aevaluate(
        self,
        query: Optional[str] = None,
        response: Optional[str] = None,
        contexts: Optional[Sequence[str]] = None,
        sleep_time_in_seconds: int = 0,
        **kwargs: Any,
    ) -> EvaluationResult:
        """Evaluate whether the query and response pair passes the guidelines."""
        del contexts  # Unused
        del kwargs  # Unused
        if query is None or response is None:
            raise ValueError("query and response must be provided")

        logger.debug("prompt: %s", self._eval_template)
        logger.debug("query: %s", query)
        logger.debug("response: %s", response)
        logger.debug("guidelines: %s", self._guidelines)

        await asyncio.sleep(sleep_time_in_seconds)

        eval_response = await self._service_context.llm.apredict(
            self._eval_template,
            query=query,
            response=response,
            guidelines=self._guidelines,
        )
        eval_data = self._output_parser.parse(eval_response)
        eval_data = cast(EvaluationData, eval_data)

        return EvaluationResult(
            query=query,
            response=response,
            passing=eval_data.passing,
            score=1.0 if eval_data.passing else 0.0,
            feedback=eval_data.feedback,
        )