promptflow
70 строк · 2.1 Кб
1import asyncio
2import os
3from pathlib import Path
4
5from promptflow.tracing import trace
6from promptflow.core import AzureOpenAIModelConfiguration, Prompty
7
# Directory containing this script; used to locate the co-located chat.prompty file.
BASE_DIR = Path(__file__).absolute().parent
10
def log(message: str):
    """Print *message* to stdout when the VERBOSE env var is set to "true".

    Output is flushed immediately so progress is visible while streaming.
    """
    if os.environ.get("VERBOSE", "false").lower() == "true":
        print(message, flush=True)
16
class ChatFlow:
    """Streaming chat flow backed by an Azure OpenAI prompty template.

    Trims the oldest chat-history turns until the rendered prompt fits
    within the configured token budget, then streams the model's answer.
    """

    def __init__(
        self, model_config: AzureOpenAIModelConfiguration, max_total_token=1100
    ):
        # Model connection/deployment settings and the token budget allowed
        # for the fully rendered prompt.
        self.model_config = model_config
        self.max_total_token = max_total_token

    @trace
    async def __call__(
        self, question: str = "What is ChatGPT?", chat_history: list = None
    ) -> str:
        """Flow entry function."""

        prompty = Prompty.load(
            source=BASE_DIR / "chat.prompty",
            model={"configuration": self.model_config},
        )

        history = chat_history or []
        # Drop the oldest turn one at a time until the estimated prompt size
        # fits the token budget, or the history is exhausted.
        while history:
            estimated = prompty.estimate_token_count(
                question=question, chat_history=history
            )
            if estimated <= self.max_total_token:
                break
            history = history[1:]
            log(f"Reducing chat history count to {len(history)} to fit token limit")

        # The prompty is configured with stream enabled, so calling it yields
        # the answer as a sequence of string chunks.
        for chunk in prompty(question=question, chat_history=history):
            yield chunk
53
if __name__ == "__main__":
    from promptflow.tracing import start_trace

    start_trace()

    config = AzureOpenAIModelConfiguration(
        connection="open_ai_connection", azure_deployment="gpt-35-turbo"
    )
    flow = ChatFlow(model_config=config)

    # __call__ is an async generator function, so this call only creates the
    # generator; chunks are produced when it is iterated below.
    stream = flow("What's Azure Machine Learning?", [])

    # Print the result chunk-by-chunk to demonstrate streaming output.
    async def consume_result():
        async for chunk in stream:
            print(chunk, end="")
            await asyncio.sleep(0.01)

    asyncio.run(consume_result())