# dream
# 81 lines · 2.3 KB
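# Requirements:
#   tensorflow==1.14.0
#   tensorflow_text==0.1.0
#   tensorflow-hub==0.7.0
# Model setup: download and unpack the model archive, then point MODEL_PATH at the unpacked module:
#   wget http://files.deeppavlov.ai/alexaprize_data/convert_reddit_v2.8.tar.gz
#   tar xzfv convert_reddit_v2.8.tar.gz
#   MODEL_PATH=<extraction dir>/convert_data/convert
# (Note: the code below hard-codes MODEL_PATH = "/convert/convert".)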
7import logging
8import os
9
10import numpy as np
11import sentry_sdk
12import tensorflow_hub as tfhub
13import tensorflow as tf
14import tensorflow_text
15
16
# Error reporting: the DSN is read from the SENTRY_DSN environment variable.
sentry_sdk.init(os.getenv("SENTRY_DSN"))

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Bare attribute access that only marks the `tensorflow_text` import as used (so linters keep it).
# NOTE(review): presumably the import itself is what matters: it registers the custom text ops
# that the hub module below needs. Confirm before removing the import.
tensorflow_text.__name__

# Hard-coded location of the TF-Hub module directory.
# NOTE(review): looks like a container mount point; confirm against the deployment config.
MODEL_PATH = "/convert/convert"

# TF1 graph mode: one session over a fresh graph. InteractiveSession installs itself (and its graph)
# as the default, so the module and placeholders below are built in `sess.graph`.
sess = tf.InteractiveSession(graph=tf.Graph())

# Load the hub module from MODEL_PATH into the session's graph.
module = tfhub.Module(MODEL_PATH)


# Context inputs, one string per dialog (shape [None] = any batch size):
# `text_placeholder` receives the latest message and `extra_text_placeholder` the earlier history.
text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
extra_text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])

# The encode_context signature also takes the extra (earlier-history) context.
context_encoding_tensor = module(
    {"context": text_placeholder, "extra_context": extra_text_placeholder}, signature="encode_context"
)


# NOTE: "responce" is a misspelling of "response". The name is kept because encode_responses
# references this module global by name.
responce_text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])

response_encoding_tensor = module(responce_text_placeholder, signature="encode_response")

# Initialize all lookup tables and variables of the default graph; both must run before the
# encoding tensors above can be evaluated.
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
46
47
def encode_contexts(dialog_history_batch):
    """Encode a batch of dialog contexts into the response-ranking vector space.

    For every dialog, the most recent message is fed as the ``context`` input of the
    ``encode_context`` signature. All earlier messages are fed as ``extra_context``:
    they are joined with single spaces in reverse chronological order (newest first).

    Args:
        dialog_history_batch: an iterable of dialogs; each dialog is a non-empty
            sequence of strings (the dialog history, in chronological order).

    Returns:
        ``context_encoding_tensor`` evaluated in ``sess`` for the batch, in the same
        order as ``dialog_history_batch``.

    Raises:
        IndexError: if any dialog in the batch is empty.
    """
    contexts = []
    extra_context_features = []

    # Single pass over the batch so one-shot iterables (e.g. generators) also work.
    for dialog_history in dialog_history_batch:
        # The context is the most recent message in the history.
        contexts.append(dialog_history[-1])
        # Everything before it, newest first, flattened into one string.
        extra_context_features.append(" ".join(reversed(dialog_history[:-1])))

    return sess.run(
        context_encoding_tensor,
        feed_dict={text_placeholder: contexts, extra_text_placeholder: extra_context_features},
    )
71
72
def encode_responses(texts):
    """Encode candidate response strings into the response-ranking vector space.

    Args:
        texts: a batch (list) of response strings.

    Returns:
        ``response_encoding_tensor`` evaluated in ``sess`` for ``texts``.
        Presumably one encoding row per input text; confirm against the hub module.
    """
    feed = {responce_text_placeholder: texts}
    encodings = sess.run(response_encoding_tensor, feed_dict=feed)
    return encodings
75
76
def get_convert_score(contexts, responses):
    """Score (dialog context, response) pairs by the dot product of their encodings.

    Row i of ``contexts`` is scored against row i of ``responses``.

    Args:
        contexts: a batch of dialog histories, in the form accepted by ``encode_contexts``.
        responses: a batch of response strings, in the form accepted by ``encode_responses``.

    Returns:
        A column array of shape (N, 1) holding the row-wise dot products.
    """
    ctx_vecs = encode_contexts(contexts)
    # NOTE(review): encodings presumably have shape (batch, 512); confirm.
    resp_vecs = encode_responses(responses)

    elementwise = np.multiply(ctx_vecs, resp_vecs)
    row_dots = elementwise.sum(axis=1)
    return row_dots.reshape(-1, 1)
82