promptflow

Форк
0
69 строк · 2.1 Кб
1
import json
2

3
from typing import TypedDict
4
from pathlib import Path
5

6
from jinja2 import Template
7

8
from promptflow.tracing import trace
9
from promptflow.core import AzureOpenAIModelConfiguration
10
from promptflow.core._flow import Prompty
11

12
BASE_DIR = Path(__file__).absolute().parent
13

14

15
@trace
16
def load_prompt(jinja2_template: str, code: str, examples: list) -> str:
17
    """Load prompt function."""
18
    with open(BASE_DIR / jinja2_template, "r", encoding="utf-8") as f:
19
        tmpl = Template(f.read(), trim_blocks=True, keep_trailing_newline=True)
20
        prompt = tmpl.render(code=code, examples=examples)
21
        return prompt
22

23

24
class Result(TypedDict):
25
    correctness: float
26
    readability: float
27
    explanation: str
28

29

30
class CodeEvaluator:
31
    def __init__(self, model_config: AzureOpenAIModelConfiguration):
32
        self.model_config = model_config
33

34
    def __call__(self, code: str) -> Result:
35
        """Evaluate the code based on correctness, readability."""
36
        prompty = Prompty.load(
37
            source=BASE_DIR / "eval_code_quality.prompty",
38
            model={"configuration": self.model_config},
39
        )
40
        output = prompty(code=code)
41
        output = json.loads(output)
42
        output = Result(**output)
43
        return output
44

45
    def __aggregate__(self, line_results: list) -> dict:
46
        """Aggregate the results."""
47
        total = len(line_results)
48
        avg_correctness = sum(int(r["correctness"]) for r in line_results) / total
49
        avg_readability = sum(int(r["readability"]) for r in line_results) / total
50
        return {
51
            "average_correctness": avg_correctness,
52
            "average_readability": avg_readability,
53
            "total": total,
54
        }
55

56

57
if __name__ == "__main__":
58
    from promptflow.tracing import start_trace
59

60
    start_trace()
61
    model_config = AzureOpenAIModelConfiguration(
62
        connection="open_ai_connection",
63
        azure_deployment="gpt-35-turbo",
64
    )
65
    evaluator = CodeEvaluator(model_config)
66
    result = evaluator('print("Hello, world!")')
67
    print(result)
68
    aggregate_result = evaluator.__aggregate__([result])
69
    print(aggregate_result)
70

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.