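"""Flask service that wraps the OpenAI text completions and chat completions APIs.

It exposes three POST endpoints:
  * /ping            -- health check.
  * /respond         -- generate candidate responses for batches of dialogue contexts and prompts.
  * /generate_goals  -- generate meta-goals for given skill prompts via META_GOALS_PROMPT.

The target model is selected with the PRETRAINED_MODEL_NAME_OR_PATH environment variable;
models listed in CHAT_COMPLETION_MODELS are queried through the chat completions endpoint,
all other models through the plain text completions endpoint.
"""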
import json
import logging
import os
import time
import sentry_sdk

from openai import OpenAI
from common.prompts import META_GOALS_PROMPT
from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration


sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
NAMING = ["AI", "Human"]
CHATGPT_ROLES = ["assistant", "user"]

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")
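# Default generation parameters for each supported model, loaded from the bundled JSON configs.
# A request may override them by passing its own "configs" list.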
DEFAULT_CONFIGS = {
    "text-davinci-003": json.load(open("common/generative_configs/openai-text-davinci-003.json", "r")),
    "gpt-3.5-turbo": json.load(open("common/generative_configs/openai-chatgpt.json", "r")),
    "gpt-3.5-turbo-16k": json.load(open("common/generative_configs/openai-chatgpt.json", "r")),
    "gpt-4": json.load(open("common/generative_configs/openai-chatgpt.json", "r")),
    "gpt-4-32k": json.load(open("common/generative_configs/openai-chatgpt.json", "r")),
    "gpt-4-1106-preview": json.load(open("common/generative_configs/openai-chatgpt.json", "r")),
}
CHAT_COMPLETION_MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4-1106-preview"]


def generate_responses(context, openai_api_key, openai_org, prompt, generation_params, continue_last_uttr=False):
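    """Query the configured OpenAI model and return a list of generated responses.

    Args:
        context: list of utterances ordered oldest to newest; the last one is always
            assigned the human/user role.
        openai_api_key: OpenAI API key supplied by the caller.
        openai_org: optional OpenAI organization id.
        prompt: system prompt (chat models) or text prepended to the dialogue (other models).
        generation_params: keyword arguments passed to the OpenAI create() call.
        continue_last_uttr: for text completion models only; if True, no new speaker tag
            is appended, so the model continues the last utterance instead of starting a new turn.

    Returns:
        A list of response strings, one per returned choice.
    """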
    outputs = []

    assert openai_api_key, "Error: OpenAI API key is not specified in env"
    client = OpenAI(api_key=openai_api_key, organization=openai_org if openai_org else None)

    if PRETRAINED_MODEL_NAME_OR_PATH in CHAT_COMPLETION_MODELS:
        logger.info("Using the chat completions endpoint")
        s = len(context) % 2
        messages = [
            {"role": "system", "content": prompt},
        ]
        messages += [
            {
                "role": f"{CHATGPT_ROLES[(s + uttr_id) % 2]}",
                "content": uttr,
            }
            for uttr_id, uttr in enumerate(context)
        ]
        logger.info(f"context inside generate_responses seen as: {messages}")
        response = client.chat.completions.create(
            model=PRETRAINED_MODEL_NAME_OR_PATH, messages=messages, **generation_params
        )
    else:
        dialog_context = ""
        if prompt:
            dialog_context += prompt + "\n"
        s = len(context) % 2
        context = [f"{NAMING[(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
        if continue_last_uttr:
            dialog_context += "\n".join(context)
        else:
            dialog_context += "\n".join(context) + f"\n{NAMING[0]}:"
        logger.info(f"context inside generate_responses seen as: {dialog_context}")
        response = client.completions.create(
            model=PRETRAINED_MODEL_NAME_OR_PATH, prompt=dialog_context, **generation_params
        )
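
    # model_dump() turns the SDK response object into a plain dict; chat completion
    # choices expose the text under choices[i]["message"]["content"], while text
    # completion choices expose it under choices[i]["text"].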
    response = response.model_dump()
    outputs = [
        resp["message"]["content"].strip() if "message" in resp else resp.get("text", "").strip()
        for resp in response["choices"]
    ]

    if PRETRAINED_MODEL_NAME_OR_PATH not in CHAT_COMPLETION_MODELS:
        # post-process responses from all models except the chat completion ones
        outputs = [GENERATIVE_ROBOT_TEMPLATE.sub("\n", resp).strip() for resp in outputs]
    return outputs


@app.route("/ping", methods=["POST"])
def ping():
    return "pong"


@app.route("/respond", methods=["POST"])
def respond():
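    """Generate candidate responses for a batch of dialogue contexts.

    Expects a JSON body with parallel lists "dialog_contexts", "prompts", "openai_api_keys",
    and optionally "openai_api_organizations" and "configs". Returns a list of lists of
    candidate responses (one inner list per context); on failure an empty candidate is
    returned for every context and the exception is reported to Sentry.
    """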
    st_time = time.time()
    contexts = request.json.get("dialog_contexts", [])
    prompts = request.json.get("prompts", [])
    # default prompts before configs so that both lists match the batch size
    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]
    openai_api_keys = request.json.get("openai_api_keys", [])
    openai_orgs = request.json.get("openai_api_organizations", None)
    openai_orgs = [None] * len(contexts) if openai_orgs is None else openai_orgs

    try:
        responses = []
        for context, openai_api_key, openai_org, prompt, config in zip(
            contexts, openai_api_keys, openai_orgs, prompts, configs
        ):
            curr_responses = []
            outputs = generate_responses(context, openai_api_key, openai_org, prompt, config)
            for response in outputs:
                # drop degenerate candidates shorter than two characters
                if len(response) >= 2:
                    curr_responses += [response]
                else:
                    curr_responses += [""]
            responses += [curr_responses]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"openai-api result: {responses}")
    total_time = time.time() - st_time
    logger.info(f"openai-api exec time: {total_time:.3f}s")
    return jsonify(responses)


@app.route("/generate_goals", methods=["POST"])
def generate_goals():
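    """Generate meta-goals for each given skill prompt.

    Expects a JSON body with parallel lists "prompts", "openai_api_keys", and optionally
    "openai_api_organizations" and "configs". Each prompt is wrapped into META_GOALS_PROMPT
    and sent to the model; the first generated candidate is returned per prompt.
    """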
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]
    openai_api_keys = request.json.get("openai_api_keys", [])
    openai_orgs = request.json.get("openai_api_organizations", None)
    openai_orgs = [None] * len(prompts) if openai_orgs is None else openai_orgs
    try:
        responses = []
        for openai_api_key, openai_org, prompt, config in zip(openai_api_keys, openai_orgs, prompts, configs):
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, openai_api_key, openai_org, "", config)[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"openai-api generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)
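

# Example request to the /respond endpoint (illustrative; the field names are the ones read
# by the handlers above, the API key value is a placeholder):
#
#   POST /respond
#   {
#       "dialog_contexts": [["Hi!", "Hello! How can I help you?", "Tell me a joke."]],
#       "prompts": ["You are a friendly assistant."],
#       "openai_api_keys": ["<OPENAI_API_KEY>"],
#       "openai_api_organizations": [null],
#       "configs": [null]
#   }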