import json
import logging
import os
import time

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from peft import PeftModel, PeftConfig
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from common.prompts import META_GOALS_PROMPT
from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE


sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
LANGUAGE = os.getenv("LANGUAGE", "EN")
# Speaker labels used to render the dialog context; index 0 is the bot, index 1 the human.
NAMING = {
    "EN": ["AI", "Human"],
    "RU": ["Чат-бот", "Человек"],
}

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")

# Default generation parameters keyed by service name. All three entries currently
# load the same JSON file; this mapping is not used elsewhere in this module
# (generation reads the `default_config` loaded below instead).
DEFAULT_CONFIGS = {
    "transformers-lm-bloomz7b": json.load(open("common/generative_configs/default_generative_config.json", "r")),
    "transformers-lm-gptj": json.load(open("common/generative_configs/default_generative_config.json", "r")),
    "transformers-lm-oasst12b": json.load(open("common/generative_configs/default_generative_config.json", "r")),
}
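
# For reference, a generative config JSON of this kind usually carries Hugging Face
# generation parameters; the values below are illustrative only, not read from this file:
#   {"max_new_tokens": 64, "do_sample": true, "top_p": 0.9, "temperature": 0.9}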


def generate_responses(context, model, tokenizer, prompt, continue_last_uttr=False):
    outputs = []
    dialog_context = ""
    if prompt:
        dialog_context += prompt + "\n"
    # Offset the alternating speaker labels so that the last utterance is always
    # attributed to the human (index 1) and the generated reply to the bot (index 0).
    s = len(context) % 2
    context = [f"{NAMING[LANGUAGE][(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
    if continue_last_uttr:
        dialog_context += "\n".join(context)
    else:
        dialog_context += "\n".join(context) + f"\n{NAMING[LANGUAGE][0]}:"

    logger.info(f"context inside generate_responses seen as: {dialog_context}")
    data = tokenizer([dialog_context], return_tensors="pt")
    data = {k: v.to(model.device) for k, v in data.items() if k in ("input_ids", "attention_mask")}

    # `default_config` is the module-level GenerationConfig loaded at startup below.
    with torch.no_grad():
        chat_history_ids = model.generate(
            **data,
            generation_config=default_config,
        )
    if torch.cuda.is_available():
        chat_history_ids = chat_history_ids.cpu()
    for result in chat_history_ids:
        output = tokenizer.decode(result, skip_special_tokens=True)
        # Strip the echoed dialog context, then keep only the first generated turn.
        result_cut = output.replace(dialog_context + " ", "")
        result_cut = [x.strip() for x in GENERATIVE_ROBOT_TEMPLATE.split(result_cut) if x.strip()][0]
        logger.info(f"hypothesis: {result_cut}")
        outputs.append(result_cut)

    return outputs
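
# Minimal usage sketch (the context and prompt are hypothetical; relies on the
# module-level `model`, `tokenizer`, and `default_config` initialized below):
#   hypotheses = generate_responses(
#       ["hi", "what can you do?"], model, tokenizer, "You are a helpful assistant."
#   )
#   # -> one decoded continuation, e.g. ["I can answer questions and chat with you."]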


try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    default_config = GenerationConfig.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    # Read the PEFT adapter config to find the base model, load the base causal LM,
    # then wrap it with the adapter weights.
    config = PeftConfig.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float16,
        # load_in_8bit=True,
        # device_map="auto"
    )
    model = PeftModel.from_pretrained(model, PRETRAINED_MODEL_NAME_OR_PATH)
    model.eval()

    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("transformers_peft_lm is set to run on cuda")

    # Startup smoke test. The fifth positional argument of generate_responses is
    # continue_last_uttr, so the generation config is not passed here; the function
    # reads the module-level `default_config`.
    example_response = generate_responses(
        ["What is the goal of SpaceX?"], model, tokenizer, "You are a SpaceX Assistant."
    )
    logger.info(f"example response: {example_response}")
    logger.info("transformers_peft_lm is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e
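
# Example environment for running this worker (values are illustrative, not mandated
# by this file): PRETRAINED_MODEL_NAME_OR_PATH should point to a PEFT adapter
# checkpoint whose PeftConfig names the base causal LM, e.g.
#   PRETRAINED_MODEL_NAME_OR_PATH=<org>/<adapter> LANGUAGE=EN SENTRY_DSN=<dsn>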


@app.route("/ping", methods=["POST"])
def ping():
    return "pong"


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("dialog_contexts", [])
    prompts = request.json.get("prompts", [])
    # If the client sent contexts but no prompts, pad with empty prompts.
    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)

    try:
        responses = []
        for context, prompt in zip(contexts, prompts):
            curr_responses = []
            outputs = generate_responses(context, model, tokenizer, prompt)
            # Replace degenerate hypotheses (shorter than two characters) with an
            # empty string so the output lists stay aligned with the inputs.
            for response in outputs:
                if len(response) >= 2:
                    curr_responses += [response]
                else:
                    curr_responses += [""]
            responses += [curr_responses]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"transformers_peft_lm output: {responses}")
    total_time = time.time() - st_time
    logger.info(f"transformers_peft_lm exec time: {total_time:.3f}s")
    return jsonify(responses)
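
# Example request (a sketch; the host and port depend on deployment and are not
# specified in this file):
#   curl -X POST http://localhost:8080/respond \
#     -H 'Content-Type: application/json' \
#     -d '{"dialog_contexts": [["hi", "how are you?"]], "prompts": ["You are a kind AI."]}'
# The response is a JSON list with one list of hypotheses per dialog context.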


@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts

    try:
        responses = []
        for prompt in prompts:
            # Wrap the target prompt into the meta-goals instruction and generate.
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, model, tokenizer, "")[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"transformers_peft_lm generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)
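
# Example request (same deployment caveats as /respond above):
#   curl -X POST http://localhost:8080/generate_goals \
#     -H 'Content-Type: application/json' \
#     -d '{"prompts": ["You are a travel assistant."]}'
# The response is a JSON list with one generated goals string per prompt.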