promptflow

Форк
0
70 строк · 2.1 Кб
1
import asyncio
2
import os
3
from pathlib import Path
4

5
from promptflow.tracing import trace
6
from promptflow.core import AzureOpenAIModelConfiguration, Prompty
7

8
BASE_DIR = Path(__file__).absolute().parent
9

10

11
def log(message: str):
12
    verbose = os.environ.get("VERBOSE", "false")
13
    if verbose.lower() == "true":
14
        print(message, flush=True)
15

16

17
class ChatFlow:
18
    def __init__(
19
        self, model_config: AzureOpenAIModelConfiguration, max_total_token=1100
20
    ):
21
        self.model_config = model_config
22
        self.max_total_token = max_total_token
23

24
    @trace
25
    async def __call__(
26
        self, question: str = "What is ChatGPT?", chat_history: list = None
27
    ) -> str:
28
        """Flow entry function."""
29

30
        prompty = Prompty.load(
31
            source=BASE_DIR / "chat.prompty",
32
            model={"configuration": self.model_config},
33
        )
34

35
        chat_history = chat_history or []
36
        # Try to render the prompt with token limit and reduce the history count if it fails
37
        while len(chat_history) > 0:
38
            token_count = prompty.estimate_token_count(
39
                question=question, chat_history=chat_history
40
            )
41
            if token_count > self.max_total_token:
42
                chat_history = chat_history[1:]
43
                log(
44
                    f"Reducing chat history count to {len(chat_history)} to fit token limit"
45
                )
46
            else:
47
                break
48

49
        # output is a generator of string as prompty enabled stream parameter
50
        for output in prompty(question=question, chat_history=chat_history):
51
            yield output
52

53

54
if __name__ == "__main__":
55
    from promptflow.tracing import start_trace
56

57
    start_trace()
58
    config = AzureOpenAIModelConfiguration(
59
        connection="open_ai_connection", azure_deployment="gpt-35-turbo"
60
    )
61
    flow = ChatFlow(model_config=config)
62
    result = flow("What's Azure Machine Learning?", [])
63

64
    # print result in stream manner
65
    async def consume_result():
66
        async for output in result:
67
            print(output, end="")
68
            await asyncio.sleep(0.01)
69

70
    asyncio.run(consume_result())
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.