import logging
import time
import os
import random

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from flask import Flask, request, jsonify
from healthcheck import HealthCheck
import sentry_sdk
from sentry_sdk.integrations.flask import FlaskIntegration

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get(
    "PRETRAINED_MODEL_NAME_OR_PATH", "DeepPavlov/rudialogpt3_medium_based_on_gpt2_v2"
)
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")

cuda = torch.cuda.is_available()
if cuda:
    torch.cuda.set_device(0)
    device = "cuda"
else:
    device = "cpu"

try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH).to(device)
    model.eval()

    logger.info("dialogpt model is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

logger.info(f"dialogpt is set to run on {device}")

SHORT_UTTERANCE_PROBA = 0.7
MAX_HISTORY_DEPTH = os.environ.get("MAX_HISTORY_DEPTH")
# fall back to the inference default of 4 when the env var is unset:
# slicing with `context[-None:]` below would otherwise raise a TypeError
MAX_HISTORY_DEPTH = int(MAX_HISTORY_DEPTH) if MAX_HISTORY_DEPTH else 4

params_default = {
    "max_length": 128,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 20,
    "top_p": 0.9,
    "temperature": 0.7,
    "num_return_sequences": 3,
    # the two keys below are service-level settings, not generation kwargs;
    # they are filtered out before being passed to model.generate()
    "device": device,
    "is_always_use_length": True,
}


def inputs_by_length(input_: dict, length_rep=None):
    """Format one utterance as `|speaker|length_bucket|text` for the model."""
    if length_rep is None:
        length_rep = len(tokenizer.encode(input_["text"]))
    if params_default["is_always_use_length"]:
        if length_rep <= 15:
            length_param = "1"
        elif length_rep <= 50:
            length_param = "2"
        elif length_rep <= 256:
            length_param = "3"
        else:
            length_param = "-"
    else:
        length_param = "-"
    return f"|{input_['speaker']}|{length_param}|{input_['text']}"


def format_dialogue_with_target(context, context_lengths, context_depth=3, encode=False, tokenizer=None):
    """
    THE LAST UTTERANCE IN THE CONTEXT IS TARGET BOT'S UTTERANCE

    context: List(dict)
    context = [
        {"text": "hi", "speaker": "human"},
        {"text": "hi there", "speaker": "bot"},
        {"text": "how are you", "speaker": "human"},
        {"text": "great how are u", "speaker": "bot"},
    ]
    OR
    context = [
        "hi",
        "hi there",
        "how are you",
        "great how are u"
    ]
    """
    if len(context) > 0 and isinstance(context[0], str):
        context_len = len(context)
        # the last uttr is from BOT
        inputs = [{"text": uttr, "speaker": (context_len - uttr_id) % 2} for uttr_id, uttr in enumerate(context)]
        inputs = inputs[-context_depth:]
    else:
        inputs = [{"text": uttr["text"], "speaker": 1 if uttr["speaker"] == "bot" else 0} for uttr in context]
        inputs = inputs[-context_depth:]

    inputs_text = "".join([inputs_by_length(input_, inp_len) for input_, inp_len in zip(inputs, context_lengths)])

    if encode:
        # if encode, return encoded context
        inputs_token_ids = tokenizer.encode(inputs_text, return_tensors="pt")
        return inputs_token_ids

    return inputs_text
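
# Illustrative sketch: for context ["hi", "hi there"] with precomputed token
# lengths [1, 2], this yields "|0|1|hi|1|1|hi there" (speaker 1 is the bot,
# and both utterances fall into length bucket "1").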


def format_dialogue_for_inference(context, context_depth=4, encode=False, tokenizer=None):
    """
    THE LAST UTTERANCE IN THE CONTEXT IS TARGET HUMAN'S UTTERANCE

    context: List(dict)
    context = [
        {"text": "hi", "speaker": "human"},
        {"text": "hi there", "speaker": "bot"},
        {"text": "how are you", "speaker": "human"},
    ]
    OR
    context = [
        "hi",
        "hi there",
        "how are you",
    ]
    """
    if len(context) > 0 and isinstance(context[0], str):
        context_len = len(context)
        # the last uttr is from HUMAN
        inputs = [{"text": uttr, "speaker": (context_len - uttr_id - 1) % 2} for uttr_id, uttr in enumerate(context)]
        inputs = inputs[-context_depth:]
    else:
        inputs = [{"text": uttr["text"], "speaker": 1 if uttr["speaker"] == "bot" else 0} for uttr in context]
        inputs = inputs[-context_depth:]

    inputs_text = "".join([inputs_by_length(input_) for input_ in inputs])
    # prime the model to answer as the bot (speaker "1"); request a short reply
    # (length bucket "1") with probability SHORT_UTTERANCE_PROBA, else a medium one
    length = "2" if random.uniform(0, 1) > SHORT_UTTERANCE_PROBA else "1"
    inputs_text += f"|1|{length}|"

    if encode:
        # if encode, return encoded context
        inputs_token_ids = tokenizer.encode(inputs_text, return_tensors="pt")
        return inputs_token_ids

    return inputs_text
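
# Illustrative sketch: format_dialogue_for_inference(["hi", "hi there", "how are you"])
# returns "|0|1|hi|1|1|hi there|0|1|how are you|1|1|"; the trailing "|1|<length>|"
# is the prompt for the bot's reply (the bucket shown here as "1" is sampled).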


app = Flask(__name__)
health = HealthCheck(app, "/healthcheck")
logging.getLogger("werkzeug").setLevel("WARNING")


@app.route("/ping", methods=["POST"])
def ping():
    return "pong"


def generate(context, num_return_sequences, context_depth):
    bot_input_ids = format_dialogue_for_inference(
        context, context_depth=context_depth, encode=True, tokenizer=tokenizer
    )
    bot_input_ids = bot_input_ids.to(device)
    params_default["num_return_sequences"] = num_return_sequences
    # drop the service-level keys: recent versions of model.generate() reject
    # unknown kwargs such as "device" and "is_always_use_length"
    generation_params = {k: v for k, v in params_default.items() if k not in ("device", "is_always_use_length")}

    chat_history_ids = model.generate(bot_input_ids, pad_token_id=tokenizer.eos_token_id, **generation_params)
    # keep only the newly generated tokens, i.e. everything after the prompt
    resp_tokens = chat_history_ids[:, bot_input_ids.shape[-1] :]
    outputs = [tokenizer.decode(x, skip_special_tokens=True) for x in resp_tokens]
    # cut each hypothesis at the first "|", where the next speaker turn begins
    outputs = [x.split("|")[0] for x in outputs]

    return outputs
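
# Quick smoke test (a sketch; the underlying model is Russian, so a Russian
# context is typical):
#   generate([{"text": "привет", "speaker": "human"}], num_return_sequences=2, context_depth=4)
# returns a list of two decoded hypotheses for the bot's next utterance.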
@app.route("/respond", methods=["POST"])
180
def respond():
181
    st_time = time.time()
182

183
    dialog_contexts = request.json.get("dialog_contexts", [])
184
    num_return_sequences = request.json.get("num_return_sequences", 3)
185

186
    try:
187
        batch_generated_responses = []
188
        for context in dialog_contexts:
189
            # context is a list of dicts, each dict contains text and speaker label
190
            # context = [{"text": "utterance text", "speaker": "human"}, ...]
191
            logger.info(f"dialogpt inputs: {context[-MAX_HISTORY_DEPTH:]}")
192

193
            hypotheses = generate(
194
                context[-MAX_HISTORY_DEPTH:], num_return_sequences=num_return_sequences, context_depth=MAX_HISTORY_DEPTH
195
            )
196
            logger.info(f"dialogpt hypotheses: {hypotheses}")
197
            batch_generated_responses.append(hypotheses)
198

199
    except Exception as exc:
200
        logger.exception(exc)
201
        sentry_sdk.capture_exception(exc)
202
        batch_generated_responses = [[]] * len(dialog_contexts)
203

204
    total_time = time.time() - st_time
205
    logger.info(f"dialogpt exec time: {total_time:.3f}s")
206

207
    return jsonify(batch_generated_responses)
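
# Example request (a sketch; host and port are deployment-specific):
#   curl -X POST http://0.0.0.0:<port>/respond \
#        -H "Content-Type: application/json" \
#        -d '{"dialog_contexts": [[{"text": "привет", "speaker": "human"}]], "num_return_sequences": 2}'
# The response is a JSON list with one list of hypotheses per dialog context.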