dream

Форк
0
221 строка · 7.6 Кб
1
import logging
2
import os
3
import random
4
import pickle
5
import time
6
import json
7
import difflib
8
import traceback
9
import re
10

11
import tensorflow_hub as tfhub
12
import tensorflow as tf
13
import tensorflow_text
14
import numpy as np
15
from flask import Flask, request, jsonify
16
from flasgger import Swagger, swag_from
17
import sentry_sdk
18

19
# No-op reference to the tensorflow_text module, presumably so the otherwise unused
# import is not flagged or removed as unused. The import itself is presumably needed
# for its side effect of registering TF Text ops used by the hub model (TODO confirm).
tensorflow_text.__name__
20

21
# Service configuration, read from the environment once at import time.
SENTRY_DSN = os.environ.get("SENTRY_DSN")
SEED = 31415
MODEL_PATH = os.environ.get("MODEL_PATH")
DATABASE_PATH = os.environ.get("DATABASE_PATH")
CONFIDENCE_PATH = os.environ.get("CONFIDENCE_PATH")
# Sampling / scoring knobs (note: the decay is read from CONVERT_CONFIDENCE_DECAY).
SOFTMAX_TEMPERATURE = float(os.environ.get("SOFTMAX_TEMPERATURE", 0.08))
CONFIDENCE_DECAY = float(os.environ.get("CONVERT_CONFIDENCE_DECAY", 0.9))
NUM_SAMPLE = int(os.environ.get("NUM_SAMPLE", 3))
29

30

31
# Error reporting, logging, and the Flask app with its flasgger Swagger integration.
sentry_sdk.init(SENTRY_DSN)

_LOG_FORMAT = "%(asctime)s %(module)s %(lineno)d %(levelname)s : %(message)s"
logging.basicConfig(
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
    level=logging.INFO,
)

logger = logging.getLogger(__name__)
app = Flask(__name__)
swagger = Swagger(app)
41

42
# Seed both the stdlib RNG and NumPy's global RNG. Candidate sampling draws from
# np.random (np.random.choice in sample_candidates), which random.seed does not affect.
random.seed(SEED)
np.random.seed(SEED)
43

44
# TF1-style session over a fresh graph. InteractiveSession installs itself as the
# default session; the hub module below is presumably built into its graph (confirm).
sess = tf.InteractiveSession(graph=tf.Graph())

module = tfhub.Module(MODEL_PATH)
# Precomputed response database: a pickled (response_encodings, responses) pair.
# NOTE(review): pickle.load can execute arbitrary code; DATABASE_PATH must point to a
# trusted, locally produced file. The `with` block closes the handle the original leaked.
with open(DATABASE_PATH, "rb") as _database_file:
    response_encodings, responses = pickle.load(_database_file)
# Reference confidence scores that approximate_confidence compares against.
confidences = np.load(CONFIDENCE_PATH)
49

50

51
spaces_pat = re.compile(r"\s+")
special_symb_pat = re.compile(r"[^A-Za-z0-9 ]")


def clear_text(text):
    """Normalise ``text`` for fuzzy comparison.

    Lower-cases the text and removes every character other than ASCII letters,
    digits and spaces. It then collapses every run of whitespace into one space.
    Apostrophes (including the typographic U+2019) are removed like any other symbol.

    Args:
        text: the raw utterance.

    Returns:
        The normalised string without leading/trailing spaces.
    """
    # \s+ also covers newlines, so no separate "\n" replacement is needed.
    text = spaces_pat.sub(" ", text.lower())
    # Dropping symbols can leave adjacent spaces behind ("a - b" -> "a  b"),
    # so collapse whitespace once more afterwards.
    text = spaces_pat.sub(" ", special_symb_pat.sub("", text))
    return text.strip()
59

60

61
def _load_json(path):
    """Read and return the JSON document stored at ``path``, closing the file afterwards."""
    # JSON text is UTF-8 by spec; don't depend on the process locale.
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


# Responses that must never be returned, normalised with clear_text for fuzzy matching.
banned_responses = [clear_text(utter) for utter in _load_json("./banned_responses.json")]
banned_phrases = _load_json("./banned_phrases.json")
banned_words = _load_json("./banned_words.json")
banned_words_for_questions = _load_json("./banned_words_for_questions.json")
66

67
# Graph inputs: a batch of latest messages and a batch of their joined earlier turns.
text_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
extra_text_placeholder = tf.placeholder(shape=[None], dtype=tf.string)

# The encode_context signature takes the extra context alongside the latest message.
_encode_context_inputs = {"context": text_placeholder, "extra_context": extra_text_placeholder}
context_encoding_tensor = module(_encode_context_inputs, signature="encode_context")

# Initialise lookup tables, then variables, now that the module is in the graph.
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
77

78

79
def encode_context(dialogue_history):
    """Map a dialogue history to one vector in the response ranking space.

    The newest message is fed as ``context``. All earlier messages are fed as
    ``extra_context``: newest first, joined with single spaces.

    Args:
        dialogue_history: list of utterance strings in chronological order.

    Returns:
        The first (and only) row of ``context_encoding_tensor`` evaluated for this history.
    """
    latest_turn = dialogue_history[-1]
    earlier_turns = dialogue_history[:-1]
    extra_context_feature = " ".join(reversed(earlier_turns))

    feed = {
        text_placeholder: [latest_turn],
        extra_text_placeholder: [extra_context_feature],
    }
    batch_encodings = sess.run(context_encoding_tensor, feed_dict=feed)
    return batch_encodings[0]
98

99

100
def approximate_confidence(confidence, approximate_confidence_is_enabled=True):
    """Turn a raw ranking score into the confidence reported to callers.

    Args:
        confidence: raw score for one candidate.
        approximate_confidence_is_enabled: when true, return 0.85 times the fraction
            of the module-level reference ``confidences`` that are <= ``confidence``.
            When false, return ``confidence`` unchanged, as a float.
    """
    if not approximate_confidence_is_enabled:
        return float(confidence)
    num_not_above = (confidences <= confidence).sum()
    return 0.85 * num_not_above / len(confidences)
105

106

107
def get_BOW(sentence):
    """Return the bag of words of ``sentence`` as a set.

    The sentence is lower-cased and split on every run of non-alphanumeric ASCII
    characters. Tokens shorter than three characters are dropped. Lower-casing makes
    the comparison against the lowercase reference utterances case-insensitive.
    """
    tokens = re.sub("[^A-Za-z0-9]+", " ", sentence.lower()).split()
    return {token for token in tokens if len(token) > 2}


# Requests this model should not try to answer, stored as bags of words.
unanswered_utters = ["let's talk about", "what else can you do?", "let's talk about books"]
unanswered_utters = [get_BOW(utter) for utter in unanswered_utters]


def is_unanswerable_utters(history):
    """Return True if the last utterance of ``history`` matches one of ``unanswered_utters``.

    A match is a Jaccard similarity above 0.9 between bags of words. For the short
    reference utterances above, this holds only when both bags contain exactly the
    same words.

    Args:
        history: non-empty list of utterance strings; only the last one is inspected.
    """
    last_utter = get_BOW(history[-1])
    # The union is never empty because every reference bag is non-empty.
    return any(len(last_utter & utter) / len(last_utter | utter) > 0.9 for utter in unanswered_utters)
122

123

124
def softmax(x, t):
    """Temperature-scaled softmax along axis 0.

    Args:
        x: array-like of scores. It may be 1-D, or N-D with each distribution laid out along axis 0.
        t: temperature. Values below 1 sharpen the distribution; values above 1 flatten it.

    Returns:
        A float ndarray shaped like ``x`` whose slices along axis 0 each sum to 1.
    """
    x = np.asarray(x, dtype=float)
    # Shift by the max of each slice that is normalised (not the global max).
    # exp then never overflows, and no slice underflows to all zeros (0/0 = nan).
    e_x = np.exp((x - np.max(x, axis=0, keepdims=True)) / t)
    return e_x / e_x.sum(axis=0)
127

128

129
def exponential_decay(init_value, factor, num_steps):
    """Return ``init_value`` multiplied by ``factor`` once per step (init_value * factor**num_steps)."""
    decay_multiplier = pow(factor, num_steps)
    return init_value * decay_multiplier
131

132

133
def sample_candidates(candidates, choice_num=1, replace=False, softmax_temperature=1):
    """Draw up to ``choice_num`` candidates with probability softmax(confidence / T).

    Args:
        candidates: sequence of ``(response, confidence)`` pairs.
        choice_num: how many candidates to draw; capped at ``len(candidates)``.
        replace: whether the same candidate may be drawn more than once.
        softmax_temperature: temperature passed to ``softmax``. Lower values
            concentrate the probability mass on the highest-confidence candidates.

    Returns:
        A list of ``[response, confidence]`` lists in sampling order, with the original
        value types kept. Returns an empty list when ``candidates`` is empty.

    Raises:
        ValueError: from ``np.random.choice``, e.g. when ``replace`` is False and fewer
            candidates have non-zero probability than are requested.
    """
    if not candidates:
        return []
    choice_num = min(choice_num, len(candidates))
    candidate_confidences = [cand[1] for cand in candidates]
    choice_probs = softmax(candidate_confidences, softmax_temperature)
    sampled_indices = np.random.choice(len(candidates), choice_num, replace=replace, p=choice_probs)
    # Index the original sequence rather than np.array(candidates): that array gets
    # a string dtype, which silently turned every confidence into text.
    return [list(candidates[i]) for i in sampled_indices]
142

143

144
def _resembles_banned_response(cand_tokens):
    """Return True if ``cand_tokens`` is more than 90% similar to any entry of ``banned_responses``."""
    return any(
        difflib.SequenceMatcher(None, banned.split(), cand_tokens).ratio() > 0.9 for banned in banned_responses
    )


def _violates_content_rules(cand_tokens, raw_cand):
    """Return True if a candidate triggers any greeting, word, phrase or length ban.

    Args:
        cand_tokens: the candidate after ``clear_text`` and ``split``.
        raw_cand: the lower-cased original candidate text (still contains "?").
    """
    # hello ban: a greeting among the first three tokens
    if any(greeting in cand_tokens[:3] for greeting in ["hi", "hello"]):
        return True
    # banned_words ban
    if any(word in cand_tokens for word in banned_words):
        return True
    # words that are banned only when the candidate is a question
    if "?" in raw_cand and any(word in cand_tokens for word in banned_words_for_questions):
        return True
    # banned_phrases ban (substring match on the raw text)
    if any(phrase in raw_cand for phrase in banned_phrases):
        return True
    # ban tokens longer than 30 characters
    return any(len(token) > 30 for token in cand_tokens)


def inference(utterances_histories, num_ongoing_utt, approximate_confidence_is_enabled=True):
    """Rank database responses for one dialogue and sample a few of them.

    Args:
        utterances_histories: one dialogue as a list of utterance strings, oldest first
            (a single history, despite the plural name).
        num_ongoing_utt: exponent for ``CONFIDENCE_DECAY`` applied to every returned confidence.
        approximate_confidence_is_enabled: passed through to ``approximate_confidence``.

    Returns:
        On success, ``(answers, confidences)``: two parallel lists of up to ``NUM_SAMPLE``
        items. If sampling raised, the single best candidate as ``(response, confidence)``.
        ``("", 0.0)`` in three cases:
          - the history is empty;
          - the last utterance is a request this model should not answer;
          - every candidate was filtered out.
    """
    # Cheap rejections first, before the model run and the difflib filtering.
    if not utterances_histories or is_unanswerable_utters(utterances_histories):
        return "", 0.0

    context_encoding = encode_context(utterances_histories)
    scores = context_encoding.dot(response_encodings.T)
    # The ten best-scoring responses, best first.
    top_indices = np.argsort(scores)[::-1][:10]

    # Every second utterance counting back from the one before the last. These are
    # presumably the bot's own earlier replies when turns alternate (confirm).
    # Candidates too similar to any of them are dropped below.
    previous_turns = [clear_text(utt).split() for utt in utterances_histories[::-1][1::2][::-1]]

    # Build the survivor list directly (keeping best-first order) instead of removing
    # from a list while iterating over it.
    filtered_indices = []
    for ind in top_indices:
        raw_response = responses[ind]
        cand_tokens = clear_text(raw_response).split()
        if _resembles_banned_response(cand_tokens):
            continue
        if _violates_content_rules(cand_tokens, raw_response.lower()):
            continue
        if any(difflib.SequenceMatcher(None, utt, cand_tokens).ratio() > 0.6 for utt in previous_turns):
            continue
        filtered_indices.append(ind)

    if not filtered_indices:
        return "", 0.0

    candidates = [
        (responses[ind], approximate_confidence(scores[ind], approximate_confidence_is_enabled))
        for ind in filtered_indices
    ]
    try:
        selected_candidates = sample_candidates(
            candidates, choice_num=NUM_SAMPLE, softmax_temperature=SOFTMAX_TEMPERATURE
        )
        answers = [cand[0] for cand in selected_candidates]
        # Named so it doesn't shadow the module-level reference ``confidences`` array.
        answer_confidences = [
            exponential_decay(float(cand[1]), CONFIDENCE_DECAY, num_ongoing_utt) for cand in selected_candidates
        ]
        return answers, answer_confidences
    except Exception:
        # Best-effort fallback: log and serve the highest-scoring surviving candidate.
        logger.error(traceback.format_exc())
        return (
            candidates[0][0],
            exponential_decay(float(candidates[0][1]), CONFIDENCE_DECAY, num_ongoing_utt),
        )
209

210

211
def _ongoing_utt_counts(num_ongoing_utt, num_histories):
    """Return one ongoing-utterance count per dialogue history.

    ``num_ongoing_utt`` is presumably one count per history (it arrives as a list;
    confirm with callers). Histories beyond the end of the list reuse its first
    entry, which was the old behaviour for every history. An empty list counts as 0.
    """
    fallback = num_ongoing_utt[0] if num_ongoing_utt else 0
    return [num_ongoing_utt[i] if i < len(num_ongoing_utt) else fallback for i in range(num_histories)]


@app.route("/convert_reddit", methods=["POST"])
@swag_from("chitchat_endpoint.yml")
def convert_chitchat_model():
    """POST /convert_reddit: run ``inference`` for every history in the request.

    Reads the following from the JSON body:
      - ``utterances_histories`` (required);
      - ``approximate_confidence_is_enabled`` (default True);
      - ``num_ongoing_utt`` (default ``[0]``).

    Responds with a JSON list holding one ``inference`` result per history.
    """
    st_time = time.time()
    payload = request.json
    utterances_histories = payload["utterances_histories"]
    approximate_confidence_is_enabled = payload.get("approximate_confidence_is_enabled", True)
    # Pair each history with its own count instead of applying the first count to all of them.
    ongoing_counts = _ongoing_utt_counts(payload.get("num_ongoing_utt", [0]), len(utterances_histories))
    response = [
        inference(hist, num_ongoing, approximate_confidence_is_enabled)
        for hist, num_ongoing in zip(utterances_histories, ongoing_counts)
    ]
    total_time = time.time() - st_time
    logger.warning(f"convert_reddit exec time: {total_time:.3f}s")
    return jsonify(response)
222

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.