dream
221 lines · 7.6 KB
1import logging
2import os
3import random
4import pickle
5import time
6import json
7import difflib
8import traceback
9import re
10
11import tensorflow_hub as tfhub
12import tensorflow as tf
13import tensorflow_text
14import numpy as np
15from flask import Flask, request, jsonify
16from flasgger import Swagger, swag_from
17import sentry_sdk
18
# Touch the module attribute so linters do not flag tensorflow_text as unused;
# presumably importing it registers the custom text ops that the hub module
# needs — TODO confirm against the model's requirements.
tensorflow_text.__name__

SENTRY_DSN = os.getenv("SENTRY_DSN")
SEED = 31415
MODEL_PATH = os.getenv("MODEL_PATH")
DATABASE_PATH = os.getenv("DATABASE_PATH")
CONFIDENCE_PATH = os.getenv("CONFIDENCE_PATH")
# Temperature used when sampling among top candidates; lower => greedier.
SOFTMAX_TEMPERATURE = float(os.getenv("SOFTMAX_TEMPERATURE", 0.08))
# Per-turn multiplicative decay applied to returned confidences.
CONFIDENCE_DECAY = float(os.getenv("CONVERT_CONFIDENCE_DECAY", 0.9))
# Number of candidate responses sampled per request.
NUM_SAMPLE = int(os.getenv("NUM_SAMPLE", 3))


sentry_sdk.init(SENTRY_DSN)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(module)s %(lineno)d %(levelname)s : %(message)s",
    handlers=[logging.StreamHandler()],
)

logger = logging.getLogger(__name__)
app = Flask(__name__)
swagger = Swagger(app)

random.seed(SEED)

sess = tf.InteractiveSession(graph=tf.Graph())

module = tfhub.Module(MODEL_PATH)
# Fix: close the database file handle instead of leaking the object opened
# inline in pickle.load(open(...)).
with open(DATABASE_PATH, "rb") as db_file:
    response_encodings, responses = pickle.load(db_file)
# Empirical sample of scores used by approximate_confidence() as a CDF.
confidences = np.load(CONFIDENCE_PATH)
49
50
# Compiled once at import time: runs of whitespace, and any character that is
# not an ASCII letter, digit or space.
spaces_pat = re.compile(r"\s+")
special_symb_pat = re.compile(r"[^A-Za-z0-9 ]")


def clear_text(text):
    """Normalize an utterance for fuzzy comparison.

    Lower-cases, converts newlines to spaces, collapses whitespace runs, and
    strips everything except ASCII letters, digits and spaces.

    Fix: the original replaced the right single quote (U+2019) with "'" AFTER
    the character filter had already removed it, i.e. dead code. The
    replacement is moved before the filter; since the filter removes ASCII
    apostrophes too, the output is identical for every input — apostrophes
    never survive either way.
    """
    text = text.lower().replace("\n", " ").replace("\u2019", "'")
    return special_symb_pat.sub("", spaces_pat.sub(" ", text)).strip()
59
60
def _load_json(path):
    # Load a JSON file, closing the handle promptly (the original leaked the
    # file objects opened inline in json.load(open(...))).
    with open(path) as fin:
        return json.load(fin)


# Responses banned outright; pre-cleaned so they can be fuzzily matched
# against cleaned candidates in inference().
banned_responses = [clear_text(utter) for utter in _load_json("./banned_responses.json")]
banned_phrases = _load_json("./banned_phrases.json")
banned_words = _load_json("./banned_words.json")
banned_words_for_questions = _load_json("./banned_words_for_questions.json")

text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
extra_text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])

# The encode_context signature takes the most recent message plus the
# (reversed, space-joined) earlier history as extra context.
context_encoding_tensor = module(
    {"context": text_placeholder, "extra_context": extra_text_placeholder}, signature="encode_context"
)

sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
78
def encode_context(dialogue_history):
    """Project a dialogue history into the response-ranking vector space.

    Args:
        dialogue_history: list of utterance strings in chronological order.
            The last element is fed as the current context; the remaining
            utterances, most recent first, are joined into the extra-context
            feature.

    Returns:
        The context-encoding vector for the single supplied dialogue.
    """
    *previous, current = dialogue_history
    extra_feature = " ".join(reversed(previous))
    encoded = sess.run(
        context_encoding_tensor,
        feed_dict={text_placeholder: [current], extra_text_placeholder: [extra_feature]},
    )
    return encoded[0]
98
99
def approximate_confidence(confidence, approximate_confidence_is_enabled=True):
    """Map a raw ranking score to a confidence value.

    When enabled, returns 0.85 times the empirical-CDF rank of `confidence`
    within the precomputed `confidences` sample (so the result lies in
    [0, 0.85]); otherwise passes the score through as a plain float.
    """
    if not approximate_confidence_is_enabled:
        return float(confidence)
    rank_fraction = (confidences <= confidence).sum() / len(confidences)
    return 0.85 * rank_fraction
105
106
# Hoisted out of get_BOW so the pattern is compiled once, not per call.
_bow_token_pat = re.compile("[^A-Za-z0-9]+")


def get_BOW(sentence):
    """Return the bag of words of `sentence` as a set.

    Runs of non-alphanumeric characters act as separators; tokens of one or
    two characters are discarded. Case is preserved.
    """
    tokens = _bow_token_pat.sub(" ", sentence).split()
    return {token for token in tokens if len(token) > 2}
111
112
# Canned "meta" utterances this skill should not try to answer; stored as bags
# of words for the Jaccard comparison in is_unanswerable_utters().
unanswered_utters = ["let's talk about", "what else can you do?", "let's talk about books"]
unanswered_utters = [get_BOW(utter) for utter in unanswered_utters]
115
116
def is_unanswerable_utters(history):
    """Return True when the last utterance in `history` is near-identical to a
    known unanswerable meta-utterance, False otherwise.

    Similarity is the Jaccard index (> 0.9) over bags of words.

    Fix: the original fell off the end and returned None implicitly; the
    explicit False keeps the contract boolean (both are falsy, so callers
    are unaffected).
    """
    last_utter = get_BOW(history[-1])
    for utter in unanswered_utters:
        # `utter` is non-empty, so the union is never empty (no ZeroDivision).
        if len(last_utter & utter) / len(last_utter | utter) > 0.9:
            return True
    return False
122
123
def softmax(x, t):
    """Temperature-scaled softmax over a 1-D array-like `x`.

    The maximum is subtracted before exponentiating for numerical stability;
    `t` is the temperature (smaller values sharpen the distribution).
    """
    shifted = np.asarray(x) - np.max(x)
    exps = np.exp(shifted / t)
    return exps / exps.sum(axis=0)
127
128
def exponential_decay(init_value, factor, num_steps):
    """Return `init_value` after applying `factor` multiplicatively `num_steps` times."""
    return init_value * (factor ** num_steps)
131
132
def sample_candidates(candidates, choice_num=1, replace=False, softmax_temperature=1):
    """Sample candidates weighted by the softmax of their confidences.

    Args:
        candidates: list of (response, confidence) pairs.
        choice_num: number of samples to draw; capped at len(candidates).
        replace: whether to sample with replacement.
        softmax_temperature: temperature for the weighting softmax.

    Returns:
        List of sampled (response, confidence) pairs.

    Fix: the original pushed `candidates` through np.array(), which coerced
    the float confidences to strings in the returned pairs. Sampling indices
    and indexing the original list preserves the element types; callers that
    do `cand[0]` / `float(cand[1])` are unaffected.
    """
    choice_num = min(choice_num, len(candidates))
    weights = softmax([cand[1] for cand in candidates], softmax_temperature)
    sampled_indices = np.random.choice(len(candidates), choice_num, replace=replace, p=weights)
    return [candidates[i] for i in sampled_indices]
142
143
def _is_banned_candidate(cand_tokens, raw_cand):
    """Return True when a candidate trips any of the ban filters."""
    # Greeting ban: "hi"/"hello" within the first three tokens.
    if any(greeting in cand_tokens[:3] for greeting in ("hi", "hello")):
        return True
    # Globally banned words.
    if any(word in cand_tokens for word in banned_words):
        return True
    # Words banned only when the candidate is a question.
    if "?" in raw_cand and any(word in cand_tokens for word in banned_words_for_questions):
        return True
    # Banned substrings anywhere in the raw (lower-cased) text.
    if any(phrase in raw_cand for phrase in banned_phrases):
        return True
    # Suspiciously long tokens (likely garbage).
    if any(len(token) > 30 for token in cand_tokens):
        return True
    return False


def inference(utterances_histories, num_ongoing_utt, approximate_confidence_is_enabled=True):
    """Rank database responses against the dialogue history and sample answers.

    Args:
        utterances_histories: list of utterance strings, chronological order.
        num_ongoing_utt: how many consecutive turns this skill has already
            produced; used to decay the returned confidences.
        approximate_confidence_is_enabled: when True, map raw scores through
            the empirical-CDF approximation.

    Returns:
        (answers, confidences) — parallel lists of up to NUM_SAMPLE sampled
        responses and their decayed confidences; ("", 0.0) when nothing
        survives filtering; or a single (answer, confidence) pair if the
        sampling step itself fails (kept for backward compatibility with the
        original fallback path).
    """
    # Hoisted: bail out before the expensive encoding when the last user
    # utterance is one of the known unanswerable meta-requests.
    if is_unanswerable_utters(utterances_histories):
        return "", 0.0

    context_encoding = encode_context(utterances_histories)
    scores = context_encoding.dot(response_encodings.T)
    top_indices = np.argsort(scores)[::-1][:10]

    # Pass 1: drop candidates near-identical (token ratio > 0.9) to any
    # explicitly banned response.
    candidate_indices = []
    for ind in top_indices:
        cand_tokens = clear_text(responses[ind]).split()
        if any(
            difflib.SequenceMatcher(None, banned.split(), cand_tokens).ratio() > 0.9
            for banned in banned_responses
        ):
            continue
        candidate_indices.append(ind)

    # Every second utterance counting back from the one before the last —
    # presumably this skill's own previous replies, used to avoid repeating
    # itself; TODO confirm against the dialogue format.
    history_token_lists = [clear_text(utt).split() for utt in utterances_histories[::-1][1::2][::-1]]

    # Pass 2: apply the ban filters and the similarity-to-history filter.
    # Rebuilding the list fixes the original's remove-while-iterating over
    # reversed(filtered_indices); the surviving set and order are identical.
    filtered_indices = []
    for ind in candidate_indices:
        cand_tokens = clear_text(responses[ind]).split()
        raw_cand = responses[ind].lower()
        if _is_banned_candidate(cand_tokens, raw_cand):
            continue
        if any(
            difflib.SequenceMatcher(None, utterance, cand_tokens).ratio() > 0.6
            for utterance in history_token_lists
        ):
            continue
        filtered_indices.append(ind)

    if not filtered_indices:
        return "", 0.0

    candidates = [
        (responses[ind], approximate_confidence(scores[ind], approximate_confidence_is_enabled))
        for ind in filtered_indices
    ]
    try:
        selected = sample_candidates(
            candidates, choice_num=NUM_SAMPLE, softmax_temperature=SOFTMAX_TEMPERATURE
        )
        answers = [cand[0] for cand in selected]
        decayed_confidences = [
            exponential_decay(float(cand[1]), CONFIDENCE_DECAY, num_ongoing_utt) for cand in selected
        ]
        return answers, decayed_confidences
    except Exception:
        # Best-effort fallback: log and return the single top-scoring
        # candidate. NOTE(review): this shape (a bare pair) differs from the
        # success path's parallel lists — preserved for backward compatibility.
        logger.error(traceback.format_exc())
        return (
            candidates[0][0],
            exponential_decay(float(candidates[0][1]), CONFIDENCE_DECAY, num_ongoing_utt),
        )
209
210
@app.route("/convert_reddit", methods=["POST"])
@swag_from("chitchat_endpoint.yml")
def convert_chitchat_model():
    """Batch endpoint: run inference() for every dialogue history in the
    request payload and return the list of results as JSON."""
    started = time.time()
    payload = request.json
    histories = payload["utterances_histories"]
    approx_enabled = payload.get("approximate_confidence_is_enabled", True)
    # NOTE(review): only the first element of num_ongoing_utt is applied to
    # every history in the batch — confirm this is intended when batch > 1.
    num_ongoing_utt = payload.get("num_ongoing_utt", [0])
    results = [inference(history, num_ongoing_utt[0], approx_enabled) for history in histories]
    total_time = time.time() - started
    logger.warning(f"convert_reddit exec time: {total_time:.3f}s")
    return jsonify(results)
222