dream

Форк
0
115 строк · 4.3 Кб
1
import logging
2
import time
3
import os
4
import re
5
import json
6
import sentry_sdk
7
import torch
8
from flask import Flask, request, jsonify
9
from sentry_sdk.integrations.flask import FlaskIntegration
10
from transformers import GPT2LMHeadModel, GPT2Tokenizer
11
from itertools import cycle
12

13
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
14

15

16
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
17
logger = logging.getLogger(__name__)
18

19
# Required service configuration, read from the environment.  All four
# variables must be set; the int() conversions fail fast at startup if
# the numeric ones are missing or malformed.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
CONFIG_NAME = os.environ.get("CONFIG_NAME")
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE"))
MAX_HISTORY_DEPTH = int(os.environ.get("MAX_HISTORY_DEPTH"))
# Use the module logger instead of the root logger (logging.info) for
# consistency with the rest of the file.
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")

DEFAULT_CONFIDENCE = 0.9  # confidence assigned to accepted hypotheses
ZERO_CONFIDENCE = 0.0  # confidence assigned to rejected/empty hypotheses

# Generation hyperparameters passed to model.generate(); the number of
# returned sequences is forced to N_HYPOTHESES_TO_GENERATE.
# Explicit encoding: the JSON config must not depend on the locale.
with open(CONFIG_NAME, "r", encoding="utf-8") as f:
    generation_params = json.load(f)
generation_params["num_return_sequences"] = N_HYPOTHESES_TO_GENERATE

# Post-processing patterns applied to generated text.
NICK_COMPILED = re.compile(r"\@[^\s]+\b")  # @nicknames
SPECIFIC_SYMBOLS = r"[#$%&()*+/;<=>@\^_`{|}~\[\]�]"  # unwanted symbols (incl. U+FFFD)
SPECIFIC_WORDS_COMPILED = re.compile(rf"{SPECIFIC_SYMBOLS}[a-zA-Z]+{SPECIFIC_SYMBOLS}")  # words wrapped in symbols
URLS_COMPILED = re.compile(r"(http[^\s]+\b|<a href[^>]+>)")  # bare URLs and HTML anchors
ANYTHING_EXCEPT_OF_LETTERS_SPACE_AND_PUNCT_COMPILED = re.compile(SPECIFIC_SYMBOLS)
# Separator placed before each utterance when building the prompt;
# empty for this model, so utterances are joined as-is.
NEW_UTTERANCE = ""
38

39
try:
    # Load tokenizer and model once at startup; move the model to GPU
    # when one is available.
    tokenizer = GPT2Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = GPT2LMHeadModel.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("gpt2-generator is set to run on cuda")

    logger.info("gpt2-generator is ready")
except Exception as e:
    # Report the startup failure, then abort: the service cannot run
    # without the model.
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    # Bare `raise` re-raises with the original traceback (was `raise e`).
    raise
51

52
app = Flask(__name__)
53
logging.getLogger("werkzeug").setLevel("WARNING")
54

55

56
def generate_response(context, model, tokenizer):
    """Generate and post-process candidate replies for one dialog context.

    :param context: list of utterance strings, oldest first.
    :param model: GPT2LMHeadModel used for generation.
    :param tokenizer: matching GPT2Tokenizer.
    :return: list of cleaned reply strings, one per generated hypothesis
        (``generation_params["num_return_sequences"]`` of them).
    """
    # Prompt: the last MAX_HISTORY_DEPTH utterances, one per line, plus a
    # trailing empty line so the model continues with a fresh utterance.
    # NEW_UTTERANCE ("") is the per-utterance prefix.
    history = [NEW_UTTERANCE + utterance for utterance in context[-MAX_HISTORY_DEPTH:]]
    text = "\n".join(history + [NEW_UTTERANCE])

    input_ids = tokenizer.encode(text, return_tensors="pt")
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        if use_cuda:
            input_ids = input_ids.to("cuda")

        generated = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, **generation_params)
        if use_cuda:
            generated = generated.cpu()

    candidates = []
    for hypothesis_ids in generated:
        decoded = tokenizer.decode(hypothesis_ids, skip_special_tokens=True)
        # Keep only the first newly generated line after the prompt.
        candidate = decoded[len(text) :].lstrip().split("\n")[0]
        # Strip nicknames, URLs, symbol-wrapped words, then any leftover
        # special symbols.
        candidate = NICK_COMPILED.sub("", candidate)
        candidate = URLS_COMPILED.sub("", candidate)
        candidate = SPECIFIC_WORDS_COMPILED.sub("", candidate)
        candidate = ANYTHING_EXCEPT_OF_LETTERS_SPACE_AND_PUNCT_COMPILED.sub("", candidate).strip()
        candidates.append(candidate)
    return candidates
77

78

79
@app.route("/respond", methods=["POST"])
def respond():
    """POST /respond: generate reply hypotheses for each dialog context.

    Expects a JSON body with ``utterances_histories`` (or, as a fallback,
    ``dialog_contexts``): a list of contexts, each a list of utterances.
    Returns a JSON list of ``[responses, confidences]`` pairs, one per
    context.  On any generation error the whole batch degrades to empty
    responses with zero confidence.
    """
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])
    if not contexts:
        contexts = request.json.get("dialog_contexts", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            curr_responses = []
            curr_confidences = []
            for response in generate_response(context, model, tokenizer):
                response = re.sub(NICK_COMPILED, "", response).strip()

                if len(response) > 3:
                    curr_responses.append(response)
                    curr_confidences.append(DEFAULT_CONFIDENCE)
                else:
                    # Too short: keep the slot, but with an empty reply
                    # and zero confidence so it is never selected.
                    curr_responses.append("")
                    curr_confidences.append(ZERO_CONFIDENCE)
            logger.info(f"gpt2-generator for context: `{context}`\n returns: {curr_responses}")
            responses.append(curr_responses)
            confidences.append(curr_confidences)

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # Degrade gracefully: one empty hypothesis per context.
        responses = [[""]] * len(contexts)
        confidences = [[ZERO_CONFIDENCE]] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"gpt2-generator exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.