# dream
# 207 lines · 6.5 KB
1import logging2import time3import os4import random5
6from transformers import AutoTokenizer, AutoModelForCausalLM7import torch8from flask import Flask, request, jsonify9from healthcheck import HealthCheck10import sentry_sdk11from sentry_sdk.integrations.flask import FlaskIntegration12
# Report unhandled exceptions (including Flask request errors) to Sentry.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id or local path; defaults to a Russian DialoGPT-style model.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get(
    "PRETRAINED_MODEL_NAME_OR_PATH", "DeepPavlov/rudialogpt3_medium_based_on_gpt2_v2"
)
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")

# Pin execution to GPU 0 when CUDA is available, otherwise fall back to CPU.
cuda = torch.cuda.is_available()
if cuda:
    torch.cuda.set_device(0)
    device = "cuda"
else:
    device = "cpu"

# Load tokenizer + model once at import time; failures are reported to Sentry
# and re-raised so the service fails fast instead of serving half-initialized.
try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH).to(device)
    model.eval()

    logger.info("dialogpt model is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

logger.info(f"dialogpt is set to run on {device}")

# Probability of asking the model for a short reply (length control token "1").
SHORT_UTTERANCE_PROBA = 0.7
# NOTE(review): stays None when the env var is unset — /respond slices with
# context[-MAX_HISTORY_DEPTH:], which raises TypeError on None; confirm the
# deployment always sets MAX_HISTORY_DEPTH.
MAX_HISTORY_DEPTH = os.environ.get("MAX_HISTORY_DEPTH")
MAX_HISTORY_DEPTH = int(MAX_HISTORY_DEPTH) if MAX_HISTORY_DEPTH else MAX_HISTORY_DEPTH

# Default generation settings. NOTE(review): "device" and "is_always_use_length"
# are service-level flags, not generate() kwargs, yet the whole dict is splatted
# into model.generate() — presumably tolerated by the pinned transformers
# version; confirm on upgrade.
params_default = {
    "max_length": 128,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 20,
    "top_p": 0.9,
    "temperature": 0.7,
    "num_return_sequences": 3,
    "device": device,
    "is_always_use_length": True,
}
59
60
def inputs_by_length(input_: dict, length_rep=None):
    """Render one utterance as a control-token string: ``|speaker|length|text``.

    :param input_: dict with "speaker" (0 human / 1 bot) and "text" keys.
    :param length_rep: token count of the text; computed with the module-level
        tokenizer when not supplied.
    :return: formatted utterance string.
    """
    if length_rep is None:
        length_rep = len(tokenizer.encode(input_["text"]))

    # "-" means "no length hint"; otherwise bucket the token count.
    length_param = "-"
    if params_default["is_always_use_length"]:
        for bound, bucket in ((15, "1"), (50, "2"), (256, "3")):
            if length_rep <= bound:
                length_param = bucket
                break

    return f"|{input_['speaker']}|{length_param}|{input_['text']}"
77
def format_dialogue_with_target(context, context_lengths, context_depth=3, encode=False, tokenizer=None):
    """Serialize a dialogue whose LAST utterance is the target bot reply.

    context may be a list of dicts::

        [
            {"text": "hi", "speaker": "human"},
            {"text": "hi there", "speaker": "bot"},
            {"text": "how are you", "speaker": "human"},
            {"text": "great how are u", "speaker": "bot"},
        ]

    or a plain list of strings (speakers inferred by parity, last one = bot).

    :param context_lengths: per-utterance token counts passed through to
        inputs_by_length.
    :param context_depth: keep only this many trailing utterances.
    :param encode: when True, return token ids from *tokenizer* instead of text.
    """
    if context and isinstance(context[0], str):
        # Bare strings: alternate speakers backwards from the final
        # utterance, which belongs to the bot (speaker id 1).
        total = len(context)
        turns = [{"text": text, "speaker": (total - idx) % 2} for idx, text in enumerate(context)]
    else:
        # Dicts: map the textual speaker label onto numeric ids.
        turns = [{"text": turn["text"], "speaker": 1 if turn["speaker"] == "bot" else 0} for turn in context]
    turns = turns[-context_depth:]

    # NOTE(review): zip pairs the TRUNCATED turns with context_lengths from the
    # left; callers appear expected to pass lengths already aligned with the
    # truncated window — confirm against the caller.
    pieces = [inputs_by_length(turn, n_tokens) for turn, n_tokens in zip(turns, context_lengths)]
    serialized = "".join(pieces)

    if encode:
        # Encoded form for feeding the model directly.
        return tokenizer.encode(serialized, return_tensors="pt")

    return serialized
115
def format_dialogue_for_inference(context, context_depth=4, encode=False, tokenizer=None):
    """Serialize a dialogue whose LAST utterance is the human's, then append
    the opening control tokens for the bot's next reply.

    context may be a list of dicts::

        [
            {"text": "hi", "speaker": "human"},
            {"text": "hi there", "speaker": "bot"},
            {"text": "how are you", "speaker": "human"},
        ]

    or a plain list of strings (speakers inferred by parity, last one = human).

    :param context_depth: keep only this many trailing utterances.
    :param encode: when True, return token ids from *tokenizer* instead of text.
    """
    if context and isinstance(context[0], str):
        # Bare strings: alternate speakers backwards from the final
        # utterance, which belongs to the human (speaker id 0).
        total = len(context)
        turns = [{"text": text, "speaker": (total - idx - 1) % 2} for idx, text in enumerate(context)]
    else:
        turns = [{"text": turn["text"], "speaker": 1 if turn["speaker"] == "bot" else 0} for turn in context]
    turns = turns[-context_depth:]

    serialized = "".join(inputs_by_length(turn) for turn in turns)
    # Randomly request a short ("1") or medium ("2") reply from the model.
    target_length = "2" if random.uniform(0, 1) > SHORT_UTTERANCE_PROBA else "1"
    serialized += f"|1|{target_length}|"

    if encode:
        # Encoded form for feeding the model directly.
        return tokenizer.encode(serialized, return_tensors="pt")

    return serialized
153
app = Flask(__name__)
# Exposes GET /healthcheck for container orchestration probes.
health = HealthCheck(app, "/healthcheck")
# Silence werkzeug's per-request access logging.
logging.getLogger("werkzeug").setLevel("WARNING")
158
@app.route("/ping", methods=["POST"])
def ping():
    """Liveness probe: always answers with a plain "pong"."""
    reply = "pong"
    return reply
163
def generate(context, num_return_sequences, context_depth):
    """Sample candidate bot replies for a dialogue context.

    :param context: dialogue history (list of utterance dicts or strings),
        newest last; the final utterance is the human's.
    :param num_return_sequences: number of hypotheses to sample.
    :param context_depth: how many trailing utterances to serialize.
    :return: list of num_return_sequences decoded reply strings.
    """
    bot_input_ids = format_dialogue_for_inference(
        context, context_depth=context_depth, encode=True, tokenizer=tokenizer
    )
    bot_input_ids = bot_input_ids.to(device)

    # Copy the shared defaults instead of mutating the module-level dict:
    # Flask may serve requests concurrently, and a per-request write to
    # params_default would race between requests.
    gen_params = dict(params_default, num_return_sequences=num_return_sequences)

    chat_history_ids = model.generate(bot_input_ids, pad_token_id=tokenizer.eos_token_id, **gen_params)
    # Keep only the newly generated tail (strip the prompt tokens).
    resp_tokens = chat_history_ids[:, bot_input_ids.shape[-1] :]
    outputs = [tokenizer.decode(x, skip_special_tokens=True) for x in resp_tokens]
    # The model may start emitting the next control block; keep the text
    # before the first "|".
    outputs = [x.split("|")[0] for x in outputs]

    return outputs
178
@app.route("/respond", methods=["POST"])
def respond():
    """Batch generation endpoint.

    Request JSON: {"dialog_contexts": [...], "num_return_sequences": int}.
    Response JSON: one list of hypothesis strings per input context; on any
    generation failure the whole batch degrades to empty lists (logged and
    reported to Sentry) rather than erroring the HTTP request.
    """
    st_time = time.time()

    dialog_contexts = request.json.get("dialog_contexts", [])
    num_return_sequences = request.json.get("num_return_sequences", 3)

    try:
        batch_generated_responses = []
        for context in dialog_contexts:
            # context is a list of dicts, each dict contains text and speaker label
            # context = [{"text": "utterance text", "speaker": "human"}, ...]
            # MAX_HISTORY_DEPTH is None when the env var is unset; fall back to
            # the full context instead of crashing on context[-None:].
            depth = MAX_HISTORY_DEPTH if MAX_HISTORY_DEPTH else len(context)
            history = context[-depth:] if depth else context
            logger.info(f"dialogpt inputs: {history}")

            hypotheses = generate(
                history, num_return_sequences=num_return_sequences, context_depth=depth
            )
            logger.info(f"dialogpt hypotheses: {hypotheses}")
            batch_generated_responses.append(hypotheses)

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # Independent empty lists (not [[]] * n, which aliases one object).
        batch_generated_responses = [[] for _ in dialog_contexts]

    total_time = time.time() - st_time
    logger.info(f"dialogpt exec time: {total_time:.3f}s")

    return jsonify(batch_generated_responses)