# dream — dialogpt_persona_based service (158 lines, 5.4 KB)
import logging
import os
import time
from typing import Optional

import sentry_sdk
import torch
from flask import Flask, jsonify, request
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer

from common.utils import get_intents
13
# Report unhandled exceptions (including Flask request errors) to Sentry.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

logging.getLogger("werkzeug").setLevel("INFO")
app = Flask(__name__)

# HuggingFace model id or local path for the persona-conditioned DialoGPT.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
# Cap on how many persona sentences are prepended to the generation prompt.
# NOTE(review): int(os.environ.get(...)) raises TypeError if the variable is
# unset — the service is expected to be configured via its compose file.
MAX_PERSONA_SENTENCES = int(os.environ.get("MAX_PERSONA_SENTENCES"))

# Confidence returned when the user asks a personal open question (skill
# should win the response selection) vs. the default case.
SUPER_CONFIDENCE = 1.0
DEFAULT_CONFIDENCE = 0.9

# Marker tokens the model was fine-tuned with: speaker-turn delimiters
# (<sp_1> = user, <sp_2> = bot) and the persona span.
SPECIAL_TOKENS = {
    "<sp_1>": "<sp_1>",
    "</sp_1>": "</sp_1>",
    "<sp_2>": "<sp_2>",
    "</sp_2>": "</sp_2>",
    "<persona>": "<persona>",
    "</persona>": "</persona>",
}
# Load tokenizer and model once at import time. Any failure is reported to
# Sentry and re-raised so the container fails fast instead of serving a
# half-initialized model.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt_persona_based is set to run on cuda")

    logger.info("dialogpt_persona_based is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

def generate_response(persona: Optional[dict] = None, model=None, tokenizer=None, utterances_histories=None):
    """Generate the bot's next utterance conditioned on a short persona.

    Args:
        persona: dict with key "persona" holding the sentences most similar to
            the last user utterance (output of relative_persona_extractor).
            Defaults to None.
        model (AutoModelForCausalLM): causal LM fine-tuned with SPECIAL_TOKENS.
            Defaults to None.
        tokenizer (AutoTokenizer): tokenizer matching ``model``. Defaults to None.
        utterances_histories (List[List[str]]): batch of dialog histories;
            only the first history is used. Defaults to None.

    Returns:
        str: generated reply, or "" if the model never closed its <sp_2> turn.
    """
    vocab_tokens = tokenizer.get_added_vocab()

    # Build the "<persona> ... </persona>" prefix from the top-ranked sentences.
    max_likelihood_sentences = " ".join(persona["persona"][:MAX_PERSONA_SENTENCES])
    max_likelihood_sentences = f"{SPECIAL_TOKENS['<persona>']}{max_likelihood_sentences}{SPECIAL_TOKENS['</persona>']}"
    persona_ids = tokenizer.encode(max_likelihood_sentences, return_tensors="pt").to(device)

    # Only the most recent utterance of the first dialog is used as context.
    # (The original reversed/enumerate dance ran over a one-element slice, so
    # that utterance was always tagged as speaker 1 — build the same string
    # directly.)
    history = utterances_histories[0]
    history_chat = "".join(f"<sp_1>{item}</sp_1>" for item in history[-1:])
    history_chat += "<sp_2>"  # prompt the model to answer as speaker 2

    history_ids = tokenizer.encode(history_chat, return_tensors="pt").to(device)
    bot_input_ids = torch.cat([persona_ids, history_ids], dim=-1)

    model_response = model.generate(
        bot_input_ids,
        max_length=150,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.95,
        top_k=50,
        top_p=0.95,
    )

    # generate() returns tensors already on the model's device; the original
    # extra .to(device) before list() was a no-op and has been removed.
    model_response_list = list(model_response[0])
    if vocab_tokens["</sp_2>"] not in model_response_list:
        # The model never closed its turn — treat as "no answer".
        return ""

    # Truncate at the end-of-turn token, then decode only the generated part.
    end_speaker_index = model_response_list.index(vocab_tokens["</sp_2>"])
    model_response = model_response[:, : end_speaker_index + 1]
    # NOTE(review): slicing from len(bot_input_ids[0]) - 1 re-decodes the final
    # prompt token (<sp_2>); skip_special_tokens drops it, so output matches.
    return tokenizer.decode(model_response[0][len(bot_input_ids[0]) - 1:], skip_special_tokens=True)
114
@app.route("/respond", methods=["POST"])
def respond():
    """Flask handler: generate a persona-conditioned reply per batch item.

    Expects JSON with "last_annotated_utterances" (list of annotated utterance
    dicts) and "utterances_histories" (list of dialog histories, parallel to
    the utterances). Returns a JSON list of ([response], [confidence]) pairs.
    """
    start_time = time.time()
    responses = []
    confidences = []

    last_annotated_utterances_batch = request.json["last_annotated_utterances"]
    utterances_histories = request.json["utterances_histories"]
    try:
        for utt_pos, utterance in enumerate(last_annotated_utterances_batch):
            persona = utterance.get("annotations", {}).get("relative_persona_extractor", [])
            response = ""
            try:
                # BUGFIX: the original passed the whole batch, and
                # generate_response reads utterances_histories[0] — so every
                # batch item used the FIRST dialog's history. Wrap the matching
                # history in a one-element list instead.
                response = generate_response(
                    model=model,
                    tokenizer=tokenizer,
                    persona=persona,
                    utterances_histories=[utterances_histories[utt_pos]],
                )
            except Exception as e:
                # Best-effort per item: a failed generation yields an empty reply.
                logger.exception(e)
                response = ""

            # A personal open question gets maximal confidence so this skill
            # wins response selection.
            if "open_question_personal" in get_intents(utterance):
                logger.info("open_question_personal")
                confidence = SUPER_CONFIDENCE
            else:
                logger.info("NOT open_question_personal")
                confidence = DEFAULT_CONFIDENCE
            responses.append([response])
            confidences.append([confidence])

    except Exception as exc:
        # Batch-level failure: report and return neutral results of the right shape.
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(last_annotated_utterances_batch)
        confidences = [0.0] * len(last_annotated_utterances_batch)

    total_time = time.time() - start_time
    logger.info(f"dialogpt_persona_based exec time: {total_time:.3f}s")

    return jsonify(list(zip(responses, confidences)))