promptflow

Форк
0
70 строк · 2.2 Кб
1
from dataclasses import dataclass
2

3
from langchain.evaluation import load_evaluator
4
from langchain_community.chat_models import AzureChatOpenAI, ChatAnthropic
5

6
from promptflow.client import PFClient
7
from promptflow.connections import CustomConnection
8
from promptflow.tracing import trace
9

10

11
@dataclass
12
class Result:
13
    reasoning: str
14
    value: str
15
    score: float
16

17

18
class LangChainEvaluator:
19
    def __init__(self, custom_connection: CustomConnection):
20
        self.custom_connection = custom_connection
21

22
        # create llm according to the secrets in custom connection
23
        if "anthropic_api_key" in self.custom_connection.secrets:
24
            self.llm = ChatAnthropic(
25
                temperature=0,
26
                anthropic_api_key=self.custom_connection.secrets["anthropic_api_key"],
27
            )
28
        elif "openai_api_key" in self.custom_connection.secrets:
29
            self.llm = AzureChatOpenAI(
30
                deployment_name="gpt-35-turbo",
31
                openai_api_key=self.custom_connection.secrets["openai_api_key"],
32
                azure_endpoint=self.custom_connection.configs["azure_endpoint"],
33
                openai_api_type="azure",
34
                openai_api_version="2023-07-01-preview",
35
                temperature=0,
36
            )
37
        else:
38
            raise ValueError("No valid API key found in the connection.")
39
        # evaluate with langchain evaluator for conciseness
40
        self.evaluator = load_evaluator(
41
            "criteria", llm=self.llm, criteria="conciseness"
42
        )
43

44
    @trace
45
    def __call__(
46
        self,
47
        input: str,
48
        prediction: str,
49
    ) -> Result:
50
        """Evaluate with langchain evaluator."""
51

52
        eval_result = self.evaluator.evaluate_strings(
53
            prediction=prediction, input=input
54
        )
55
        return Result(**eval_result)
56

57

58
if __name__ == "__main__":
59
    from promptflow.tracing import start_trace
60

61
    start_trace()
62
    pf = PFClient()
63
    connection = pf.connections.get(name="my_llm_connection")
64
    evaluator = LangChainEvaluator(custom_connection=connection)
65
    result = evaluator(
66
        prediction="What's 2+2? That's an elementary question. "
67
        "The answer you're looking for is that two and two is four.",
68
        input="What's 2+2?",
69
    )
70
    print(result)
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.