promptflow
69 lines · 2.1 KB
1import json2
3from typing import TypedDict4from pathlib import Path5
6from jinja2 import Template7
8from promptflow.tracing import trace9from promptflow.core import AzureOpenAIModelConfiguration10from promptflow.core._flow import Prompty11
12BASE_DIR = Path(__file__).absolute().parent13
14
@trace
def load_prompt(jinja2_template: str, code: str, examples: list) -> str:
    """Render the given Jinja2 template with the code sample and few-shot examples.

    Args:
        jinja2_template: Template filename, resolved relative to BASE_DIR.
        code: The code snippet to substitute into the template.
        examples: Few-shot examples made available to the template.

    Returns:
        The fully rendered prompt text.
    """
    template_path = BASE_DIR / jinja2_template
    template_source = template_path.read_text(encoding="utf-8")
    template = Template(template_source, trim_blocks=True, keep_trailing_newline=True)
    return template.render(code=code, examples=examples)
23
# Schema of a single evaluation record: numeric quality scores plus the
# model's free-text justification. Functional TypedDict syntax, equivalent
# to the class-based form.
Result = TypedDict(
    "Result",
    {
        "correctness": float,   # score for functional correctness
        "readability": float,   # score for code readability
        "explanation": str,     # model's reasoning behind the scores
    },
)
29
class CodeEvaluator:
    """Scores a code snippet on correctness and readability via a prompty flow.

    The evaluation itself is delegated to ``eval_code_quality.prompty``,
    which is expected to reply with a JSON object matching the ``Result``
    schema.
    """

    def __init__(self, model_config: AzureOpenAIModelConfiguration):
        self.model_config = model_config

    def __call__(self, code: str) -> Result:
        """Evaluate *code* based on correctness and readability.

        Args:
            code: The code snippet to evaluate.

        Returns:
            A ``Result`` with float scores and a textual explanation.

        Raises:
            json.JSONDecodeError: If the model output is not valid JSON.
        """
        prompty = Prompty.load(
            source=BASE_DIR / "eval_code_quality.prompty",
            model={"configuration": self.model_config},
        )
        output = prompty(code=code)
        # The prompty is expected to emit a JSON object; parse it into the
        # typed Result shape.
        return Result(**json.loads(output))

    def __aggregate__(self, line_results: list) -> dict:
        """Aggregate per-line results into average scores.

        Args:
            line_results: List of ``Result``-shaped dicts from ``__call__``.

        Returns:
            A dict with ``average_correctness``, ``average_readability`` and
            ``total``. For an empty input, returns zeroed averages instead of
            raising ZeroDivisionError.
        """
        total = len(line_results)
        if total == 0:
            # Nothing to aggregate - report zeros rather than dividing by zero.
            return {"average_correctness": 0.0, "average_readability": 0.0, "total": 0}
        # Scores are floats (see Result); the previous int() coercion silently
        # truncated e.g. 4.5 -> 4 and skewed the averages, so use float().
        avg_correctness = sum(float(r["correctness"]) for r in line_results) / total
        avg_readability = sum(float(r["readability"]) for r in line_results) / total
        return {
            "average_correctness": avg_correctness,
            "average_readability": avg_readability,
            "total": total,
        }
56
if __name__ == "__main__":
    from promptflow.tracing import start_trace

    # Enable tracing so this evaluation run is recorded.
    start_trace()

    config = AzureOpenAIModelConfiguration(
        connection="open_ai_connection",
        azure_deployment="gpt-35-turbo",
    )
    code_evaluator = CodeEvaluator(config)

    # Evaluate a single snippet, then aggregate the one-element result list.
    single_result = code_evaluator('print("Hello, world!")')
    print(single_result)

    aggregated = code_evaluator.__aggregate__([single_result])
    print(aggregated)