import logging
import time
import os
import re
import json
import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from itertools import cycle

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
CONFIG_NAME = os.environ.get("CONFIG_NAME")
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE"))
MAX_HISTORY_DEPTH = int(os.environ.get("MAX_HISTORY_DEPTH"))
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")

DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0

with open(CONFIG_NAME, "r") as f:
    generation_params = json.load(f)
generation_params["num_return_sequences"] = N_HYPOTHESES_TO_GENERATE

# Patterns used to clean up generated hypotheses
NICK_COMPILED = re.compile(r"\@[^\s]+\b")
SPECIFIC_SYMBOLS = r"[#$%&()*+/;<=>@\^_`{|}~\[\]�]"
SPECIFIC_WORDS_COMPILED = re.compile(rf"{SPECIFIC_SYMBOLS}[a-zA-Z]+{SPECIFIC_SYMBOLS}")
URLS_COMPILED = re.compile(r"(http[^\s]+\b|<a href[^>]+>)")
ANYTHING_EXCEPT_OF_LETTERS_SPACE_AND_PUNCT_COMPILED = re.compile(SPECIFIC_SYMBOLS)
NEW_UTTERANCE = ""

try:
    tokenizer = GPT2Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = GPT2LMHeadModel.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("gpt2-generator is set to run on cuda")

    logger.info("gpt2-generator is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context, model, tokenizer):
    # Build a newline-separated prompt from the last MAX_HISTORY_DEPTH utterances,
    # generate hypotheses, and strip nicknames, URLs, and special symbols from them.
    text = "\n".join(
        list(map(lambda x: NEW_UTTERANCE.join(x), zip(cycle(["", ""]), context[-MAX_HISTORY_DEPTH:] + [""])))
    )
    bot_input_ids = tokenizer.encode(text, return_tensors="pt")
    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")

        chat_history_ids = model.generate(bot_input_ids, pad_token_id=tokenizer.eos_token_id, **generation_params)
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()

    outputs = [
        tokenizer.decode(x, skip_special_tokens=True)[len(text) :].lstrip().split("\n")[0] for x in chat_history_ids
    ]
    outputs = [re.sub(NICK_COMPILED, "", response) for response in outputs]
    outputs = [re.sub(URLS_COMPILED, "", response) for response in outputs]
    outputs = [re.sub(SPECIFIC_WORDS_COMPILED, "", response) for response in outputs]
    outputs = [ANYTHING_EXCEPT_OF_LETTERS_SPACE_AND_PUNCT_COMPILED.sub("", response).strip() for response in outputs]
    return outputs


@app.route("/respond", methods=["POST"])
def respond():
    # Expects a JSON payload with "utterances_histories" (or "dialog_contexts"):
    # a list of dialog contexts, each context being a list of utterance strings.
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])
    if len(contexts) == 0:
        contexts = request.json.get("dialog_contexts", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            curr_responses = []
            curr_confidences = []
            preds = generate_response(context, model, tokenizer)
            for response in preds:
                response = re.sub(NICK_COMPILED, "", response).strip()

                # drop too short responses
                if len(response) > 3:
                    curr_responses += [response]
                    curr_confidences += [DEFAULT_CONFIDENCE]
                else:
                    curr_responses += [""]
                    curr_confidences += [ZERO_CONFIDENCE]
            logger.info(f"gpt2-generator for context: `{context}`\n returns: {curr_responses}")
            responses += [curr_responses]
            confidences += [curr_confidences]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)
        confidences = [[ZERO_CONFIDENCE]] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"gpt2-generator exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
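# A minimal sketch of how the running service can be queried; the host and port
# below (0.0.0.0:8125) are assumptions -- use whatever address your deployment
# exposes for this Flask app:
#
#   import requests
#
#   result = requests.post(
#       "http://0.0.0.0:8125/respond",
#       json={"utterances_histories": [["hi", "hi. how are you?", "good. what are you doing?"]]},
#   ).json()
#   # result is a list of [hypotheses, confidences] pairs, one pair per context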