promptflow
64 строки · 1.8 Кб
1import os
2
3from dotenv import load_dotenv
4from openai.version import VERSION as OPENAI_VERSION
5
6from promptflow.tracing import trace
7
8
def get_client():
    """Create an OpenAI SDK client, picking the flavour from the environment.

    An ``OpenAI`` client (constructed with no arguments) is returned when the
    ``OPENAI_API_KEY`` environment variable holds a non-empty value.
    Otherwise an ``AzureOpenAI`` client is returned, with its ``api_version``
    taken from ``OPENAI_API_VERSION`` (default ``"2023-07-01-preview"``).

    Raises:
        Exception: If the installed ``openai`` version string starts with
            ``"0."``.
    """
    is_legacy_sdk = OPENAI_VERSION.startswith("0.")
    if is_legacy_sdk:
        raise Exception(
            "Please upgrade your OpenAI package to version >= 1.0.0 or using the command: pip install --upgrade openai."
        )

    # An unset or empty OPENAI_API_KEY both select the Azure client.
    if not os.environ.get("OPENAI_API_KEY"):
        # Imported lazily, only on the branch that needs it.
        from openai import AzureOpenAI

        azure_api_version = os.environ.get("OPENAI_API_VERSION", "2023-07-01-preview")
        return AzureOpenAI(api_version=azure_api_version)

    from openai import OpenAI

    return OpenAI()
25
26
@trace
def my_llm_tool(
    prompt: str,
    deployment_name: str,
    max_tokens: int = 120,
    temperature: float = 1.0,
    top_p: float = 1.0,
    n: int = 1,
) -> str:
    """Send ``prompt`` as a single system message and return the reply text.

    Args:
        prompt: Text placed in the one message of the chat request.
        deployment_name: Value passed as ``model``. For AOAI this is the
            user-defined deployment name, not the model name.
        max_tokens: Coerced with ``int()`` before being sent.
        temperature: Coerced with ``float()`` before being sent.
        top_p: Coerced with ``float()`` before being sent.
        n: Coerced with ``int()`` before being sent.

    Returns:
        The ``message.content`` of the first choice in the response; any
        further choices are ignored.

    Raises:
        Exception: If neither ``OPENAI_API_KEY`` nor ``AZURE_OPENAI_API_KEY``
            is present in the environment, even after ``load_dotenv()``.
    """

    def _api_key_missing() -> bool:
        # True only when both key variables are absent from the environment.
        key_names = ("OPENAI_API_KEY", "AZURE_OPENAI_API_KEY")
        return not any(name in os.environ for name in key_names)

    if _api_key_missing():
        # Fall back to a .env file, then look again.
        load_dotenv()
        if _api_key_missing():
            raise Exception(
                "Please specify environment variables: OPENAI_API_KEY or AZURE_OPENAI_API_KEY"
            )

    chat_messages = [{"content": prompt, "role": "system"}]
    # Resolve the client before coercing the numeric arguments.
    create_completion = get_client().chat.completions.create
    completion = create_completion(
        messages=chat_messages,
        model=deployment_name,
        max_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        n=int(n),
    )

    # The prompt is single, so only the first choice is read.
    first_choice = completion.choices[0]
    return first_choice.message.content
57
58
if __name__ == "__main__":
    # Script entry point: run the tool once with a sample prompt and print the reply.
    sample_prompt = "Write a simple Hello, world! program that displays the greeting message."
    reply = my_llm_tool(prompt=sample_prompt, deployment_name="gpt-35-turbo")
    print(reply)
65