# promptflow — source listing (70 lines · 2.2 KB)
from dataclasses import dataclass

from langchain.evaluation import load_evaluator
from langchain_community.chat_models import AzureChatOpenAI, ChatAnthropic

from promptflow.client import PFClient
from promptflow.connections import CustomConnection
from promptflow.tracing import trace


@dataclass
class Result:
    """Outcome of a single evaluation run.

    Instances are built by unpacking the evaluator's result mapping, so the
    field names must match the keys of that mapping exactly.
    """

    # Explanation text returned by the evaluator.
    reasoning: str
    # Verdict label returned by the evaluator.
    # NOTE(review): presumably "Y"/"N" for the criteria evaluator — confirm.
    value: str
    # Numeric score returned by the evaluator.
    # NOTE(review): the evaluator may yield an int or None here — confirm.
    score: float


class LangChainEvaluator:
    """Callable that scores a prediction for conciseness with LangChain.

    The chat model is chosen from the secrets stored in the supplied custom
    connection, and a LangChain ``"criteria"`` evaluator configured with the
    ``"conciseness"`` criterion is built once at construction time.
    """

    def __init__(self, custom_connection: CustomConnection):
        """Create the chat model and the evaluator.

        Args:
            custom_connection: Connection whose ``secrets`` contain either
                ``anthropic_api_key`` or ``openai_api_key``. The OpenAI
                branch additionally reads ``configs["azure_endpoint"]``.

        Raises:
            ValueError: If neither API key is present in the secrets.
        """
        self.custom_connection = custom_connection

        # Create the LLM according to the secrets in the custom connection.
        # The Anthropic key is checked first, so it wins when both are set.
        if "anthropic_api_key" in self.custom_connection.secrets:
            self.llm = ChatAnthropic(
                temperature=0,
                anthropic_api_key=self.custom_connection.secrets["anthropic_api_key"],
            )
        elif "openai_api_key" in self.custom_connection.secrets:
            self.llm = AzureChatOpenAI(
                deployment_name="gpt-35-turbo",
                openai_api_key=self.custom_connection.secrets["openai_api_key"],
                azure_endpoint=self.custom_connection.configs["azure_endpoint"],
                openai_api_type="azure",
                openai_api_version="2023-07-01-preview",
                temperature=0,
            )
        else:
            raise ValueError("No valid API key found in the connection.")
        # Evaluate with the LangChain criteria evaluator for conciseness.
        self.evaluator = load_evaluator(
            "criteria", llm=self.llm, criteria="conciseness"
        )

    @trace
    def __call__(
        self,
        input: str,
        prediction: str,
    ) -> Result:
        """Evaluate ``prediction`` against ``input`` with the evaluator.

        Args:
            input: The original question or prompt. The name shadows the
                builtin but is kept because callers pass it by keyword.
            prediction: The answer text to be judged.

        Returns:
            A ``Result`` built by unpacking the mapping returned by
            ``evaluate_strings``; its keys must match the ``Result`` fields.
        """
        eval_result = self.evaluator.evaluate_strings(
            prediction=prediction, input=input
        )
        return Result(**eval_result)


if __name__ == "__main__":
    # Manual run: start tracing, fetch the named connection, evaluate one
    # sample question/answer pair and print the resulting Result.
    from promptflow.tracing import start_trace

    start_trace()
    pf = PFClient()
    connection = pf.connections.get(name="my_llm_connection")
    evaluator = LangChainEvaluator(custom_connection=connection)
    result = evaluator(
        prediction="What's 2+2? That's an elementary question. "
        "The answer you're looking for is that two and two is four.",
        input="What's 2+2?",
    )
    print(result)