dream
165 lines · 5.5 KB
1import logging
2import time
3import os
4
5import sentry_sdk
6import torch
7from flask import Flask, request, jsonify
8from sentry_sdk.integrations.flask import FlaskIntegration
9from transformers import GPT2Tokenizer, GPT2LMHeadModel
10from transformers import BartForConditionalGeneration, BartTokenizer
11from string import punctuation
12
13import nltk
14from nltk.corpus import stopwords
15import re
16from nltk.tokenize import sent_tokenize
17
# One-time NLTK resource downloads needed below: stopwords corpus and the
# "punkt" sentence tokenizer used by sent_tokenize in generate_part.
nltk.download("stopwords")
nltk.download("punkt")
stop_words = stopwords.words("english")

# Error reporting: SENTRY_DSN comes from the environment; FlaskIntegration
# captures unhandled exceptions raised inside request handlers.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Confidence values returned to the dialog orchestrator: DEFAULT_CONFIDENCE
# for usable replies, ZERO_CONFIDENCE for dropped/empty ones.
DEFAULT_CONFIDENCE = 1.0
ZERO_CONFIDENCE = 0.0
# Model identifiers are injected via the environment; they may be None if the
# variables are unset, in which case from_pretrained below will fail.
BART_MODEL_NAME = os.environ.get("BART_MODEL_NAME")
FINETUNED_MODEL_NAME = os.environ.get("FINETUNED_MODEL_NAME")
# Matches any parenthesized span (non-greedy); used to strip such spans from
# generated text in generate_part.
pattern = re.compile(r"\(.*?\)")
# Suffix appended to every first story part; stripped again before the
# second part is generated.
continue_phrase = " Should I continue?"
# Load all models at import time so the service fails fast on start-up if
# weights are missing; any failure is reported to Sentry and re-raised.
try:
    # GPT-2 story generator, fine-tuned (FINETUNED_MODEL_NAME).
    tokenizer = GPT2Tokenizer.from_pretrained(FINETUNED_MODEL_NAME)
    # Left padding so batched generation continues from the prompt's end;
    # GPT-2 has no pad token, so reuse EOS.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained(FINETUNED_MODEL_NAME)
    # BART used only for mask filling in fill_mask.
    bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL_NAME, forced_bos_token_id=0)
    bart_tok = BartTokenizer.from_pretrained(BART_MODEL_NAME)
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # NOTE(review): only the GPT-2 model is moved to `device`; bart_model
    # stays on CPU (fill_mask also keeps its inputs on CPU) — presumably
    # intentional, verify if BART latency matters.
    model.to(device)
    logger.info(f"prompt_storygpt is set to run on {device}")
    logger.info("prompt_storygpt is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
# Silence per-request werkzeug access logs; service logs its own timings.
logging.getLogger("werkzeug").setLevel("WARNING")
55
56
def generate_part(texts, max_len, temp, num_sents, first):
    """Generate story text for a batch of prompts with the GPT-2 model.

    Parameters
    ----------
    texts: List[str]
        prompts; when ``first`` is False each prompt gets " At the end,"
        appended to steer the model towards an ending
    max_len: int
        ``max_length`` passed to ``model.generate``
    temp: float
        sampling temperature
    num_sents: int
        keep at most this many leading sentences of each generation
    first: bool
        True when generating a story opening, False for the ending

    Returns
    -------
    List[str]
        cleaned generated texts, one per input prompt (parenthesized spans
        removed, truncated to ``num_sents`` sentences, terminal punctuation
        normalized)
    """
    if not first:
        texts = [text + " At the end," for text in texts]
    encoding = tokenizer(texts, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            **encoding,
            max_length=max_len,
            length_penalty=-100.0,
            temperature=temp,
            do_sample=True,
        )
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    return_texts = []
    for text in generated_texts:
        text = pattern.sub("", text)  # delete everything in ()
        text = text.replace(" .", ".").replace("..", ".").replace("..", ".")
        sents = sent_tokenize(text)
        text = " ".join(sents[:num_sents])
        # Guard against an empty generation: without the truthiness check,
        # text[-1] raises IndexError when cleaning left nothing behind.
        if text and text[-1] not in ".!?":
            if text[-1] in punctuation:
                text = text[:-1]  # drop trailing non-sentence punctuation
            text += "."
        return_texts.append(text)
    return return_texts
83
84
def fill_mask(masked_phrases):
    """Run BART over phrases containing <mask> tokens and return the filled texts."""
    encoded = bart_tok(masked_phrases, return_tensors="pt")
    output_ids = bart_model.generate(encoded["input_ids"])
    filled = bart_tok.batch_decode(output_ids, skip_special_tokens=True)
    logger.info(f"Filled masks: {filled}")
    return filled
91
92
def generate_first_part(context):
    """Open a story for each topic noun.

    Parameters
    ----------
    context: List[str]
        a list consisting of nouns chosen from spacy_nounphrases annotator

    Returns
    -------
    List[str]
        generated story openings, each prefixed with "Ok, " and suffixed
        with the continuation question
    """
    nouns = context
    logger.info(f"Topic in StoryGPT service: {nouns}")
    # Seed prompts: BART fills the <mask> with a verb phrase about the noun.
    masked_phrases = [f"Let me share a story about {noun}. I <mask> {noun}" for noun in nouns]
    filled = fill_mask(masked_phrases)

    st_time = time.time()
    openings = generate_part(filled, 50, 0.8, 4, first=True)
    total_time = time.time() - st_time
    logger.info(f"Time for first part generation: {total_time:.3f}s")
    final_texts = ["Ok, " + opening + continue_phrase for opening in openings]
    logger.info(f"First parts generated: {final_texts}")
    return final_texts
116
117
def generate_second_part(context):
    """Generate an ending for each story given its first part (the continuation
    question is stripped before the text is fed back to the model)."""
    first_texts = context
    logger.info(f"Received first part: {first_texts}")
    first_texts = [text.replace(continue_phrase, "") for text in first_texts]
    st_time = time.time()
    generations = generate_part(first_texts, 100, 0.8, 5, first=False)
    # The model echoes the prompt; strip it so only the new ending remains.
    final_texts = [full.replace(prompt, "") for full, prompt in zip(generations, first_texts)]
    logger.info(f"Generated: {final_texts}")
    total_time = time.time() - st_time
    logger.info(f"Time for generation: {total_time:.3f}s")
    return final_texts
129
130
def generate_response(context):
    """Dispatch to first- or second-part generation.

    ``context`` is a pair ``(texts, first_part)``: ``texts`` is a list of
    nouns when ``first_part`` is truthy, otherwise a list of first-part
    story texts.
    """
    texts, first_part = context
    return generate_first_part(texts) if first_part else generate_second_part(texts)
138
139
@app.route("/respond", methods=["POST"])
def respond():
    """Flask endpoint: generate stories for the posted contexts.

    Expects a JSON body with "utterances_histories"; returns a JSON list of
    (reply, confidence) pairs, one per context. On any failure, empty
    replies with zero confidence are returned instead of an error.
    """
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])

    try:
        responses = []
        confidences = []
        for reply in generate_response(contexts):
            if len(reply) > 3:
                responses.append(reply)
                confidences.append(DEFAULT_CONFIDENCE)
            else:
                # drop too short responses
                responses.append("")
                confidences.append(ZERO_CONFIDENCE)
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(contexts)
        confidences = [ZERO_CONFIDENCE] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"Prompt storyGPT exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
166