# prompt_storygpt: Flask microservice that generates short stories
# (GPT-2 for generation, BART for prompt mask-filling).
import logging
2
import time
3
import os
4

5
import sentry_sdk
6
import torch
7
from flask import Flask, request, jsonify
8
from sentry_sdk.integrations.flask import FlaskIntegration
9
from transformers import GPT2Tokenizer, GPT2LMHeadModel
10
from transformers import BartForConditionalGeneration, BartTokenizer
11
from string import punctuation
12

13
import nltk
14
from nltk.corpus import stopwords
15
import re
16
from nltk.tokenize import sent_tokenize
17

18
nltk.download("stopwords")
19
nltk.download("punkt")
20
stop_words = stopwords.words("english")
21

22
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
23

24
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
25
logger = logging.getLogger(__name__)
26

27
DEFAULT_CONFIDENCE = 1.0
28
ZERO_CONFIDENCE = 0.0
29
BART_MODEL_NAME = os.environ.get("BART_MODEL_NAME")
30
FINETUNED_MODEL_NAME = os.environ.get("FINETUNED_MODEL_NAME")
31
pattern = re.compile(r"\(.*?\)")
32
continue_phrase = " Should I continue?"
33

34
try:
35
    tokenizer = GPT2Tokenizer.from_pretrained(FINETUNED_MODEL_NAME)
36
    tokenizer.padding_side = "left"
37
    tokenizer.pad_token = tokenizer.eos_token
38
    model = GPT2LMHeadModel.from_pretrained(FINETUNED_MODEL_NAME)
39
    bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL_NAME, forced_bos_token_id=0)
40
    bart_tok = BartTokenizer.from_pretrained(BART_MODEL_NAME)
41
    if torch.cuda.is_available():
42
        device = "cuda"
43
    else:
44
        device = "cpu"
45
    model.to(device)
46
    logger.info(f"prompt_storygpt is set to run on {device}")
47
    logger.info("prompt_storygpt is ready")
48
except Exception as e:
49
    sentry_sdk.capture_exception(e)
50
    logger.exception(e)
51
    raise e
52

53
app = Flask(__name__)
54
logging.getLogger("werkzeug").setLevel("WARNING")
55

56

57
def generate_part(texts, max_len, temp, num_sents, first):
58
    if not first:
59
        texts = [text + " At the end," for text in texts]
60
    encoding = tokenizer(texts, padding=True, return_tensors="pt").to(device)
61
    with torch.no_grad():
62
        generated_ids = model.generate(
63
            **encoding,
64
            max_length=max_len,
65
            length_penalty=-100.0,
66
            temperature=temp,
67
            do_sample=True,
68
        )
69
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
70

71
    return_texts = []
72
    for text in generated_texts:
73
        text = pattern.sub("", text)  # delete everything in ()
74
        text = text.replace(" .", ".").replace("..", ".").replace("..", ".")
75
        sents = sent_tokenize(text)
76
        text = " ".join(sents[:num_sents])
77
        if text[-1] not in ".!?":
78
            if text[-1] in punctuation:
79
                text = text[:-1]
80
            text += "."
81
        return_texts.append(text)
82
    return return_texts
83

84

85
def fill_mask(masked_phrases):
86
    batch = bart_tok(masked_phrases, return_tensors="pt")
87
    generated_ids = bart_model.generate(batch["input_ids"])
88
    filled = bart_tok.batch_decode(generated_ids, skip_special_tokens=True)
89
    logger.info(f"Filled masks: {filled}")
90
    return filled
91

92

93
def generate_first_part(context):
94
    """
95
    Parameters
96
    context: List[str]
97
        a list consisting of nouns chosen from spacy_nounphrases annotator
98
    Returns
99
    final_text: List[str]
100
        generated stories
101
    """
102
    nouns = context
103
    logger.info(f"Topic in StoryGPT service: {nouns}")
104
    masked_phrases = []
105
    for noun in nouns:
106
        masked_phrases.append(f"Let me share a story about {noun}. I <mask> {noun}")
107
    filled = fill_mask(masked_phrases)
108

109
    st_time = time.time()
110
    final_texts = generate_part(filled, 50, 0.8, 4, first=True)
111
    total_time = time.time() - st_time
112
    logger.info(f"Time for first part generation: {total_time:.3f}s")
113
    final_texts = ["Ok,  " + text + continue_phrase for text in final_texts]
114
    logger.info(f"First parts generated: {final_texts}")
115
    return final_texts
116

117

118
def generate_second_part(context):
119
    first_texts = context
120
    logger.info(f"Received first part: {first_texts}")
121
    first_texts = [text.replace(continue_phrase, "") for text in first_texts]
122
    st_time = time.time()
123
    final_texts = generate_part(first_texts, 100, 0.8, 5, first=False)
124
    final_texts = [final_texts[i].replace(first_texts[i], "") for i in range(len(final_texts))]
125
    logger.info(f"Generated: {final_texts}")
126
    total_time = time.time() - st_time
127
    logger.info(f"Time for generation: {total_time:.3f}s")
128
    return final_texts
129

130

131
def generate_response(context):
132
    texts, first_part = context
133
    if first_part:
134
        replies = generate_first_part(texts)  # text is a list of nouns
135
    else:
136
        replies = generate_second_part(texts)  # text is a list of first part texts
137
    return replies
138

139

140
@app.route("/respond", methods=["POST"])
141
def respond():
142
    st_time = time.time()
143
    contexts = request.json.get("utterances_histories", [])
144

145
    try:
146
        tmp_responses = generate_response(contexts)
147
        responses = []
148
        confidences = []
149
        for response in tmp_responses:
150
            if len(response) > 3:
151
                # drop too short responses
152
                responses += [response]
153
                confidences += [DEFAULT_CONFIDENCE]
154
            else:
155
                responses += [""]
156
                confidences += [ZERO_CONFIDENCE]
157
    except Exception as exc:
158
        logger.exception(exc)
159
        sentry_sdk.capture_exception(exc)
160
        responses = [""] * len(contexts)
161
        confidences = [ZERO_CONFIDENCE] * len(contexts)
162

163
    total_time = time.time() - st_time
164
    logger.info(f"Prompt storyGPT exec time: {total_time:.3f}s")
165
    return jsonify(list(zip(responses, confidences)))
166
