dream

Форк
0
81 строка · 2.3 Кб
1
# tensorflow==1.14.0
2
# tensorflow_text==0.1.0
3
# tensorflow-hub==0.7.0
4
# wget http://files.deeppavlov.ai/alexaprize_data/convert_reddit_v2.8.tar.gz
5
# tar xzfv .....
6
# MODEL_PATH=........../convert_data/convert
7
import logging
8
import os
9

10
import numpy as np
11
import sentry_sdk
12
import tensorflow_hub as tfhub
13
import tensorflow as tf
14
import tensorflow_text
15

16

17
sentry_sdk.init(os.getenv("SENTRY_DSN"))
18

19
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
20
logger = logging.getLogger(__name__)
21

22
tensorflow_text.__name__
23

24
MODEL_PATH = "/convert/convert"
25

26
sess = tf.InteractiveSession(graph=tf.Graph())
27

28
module = tfhub.Module(MODEL_PATH)
29

30

31
text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
32
extra_text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
33

34
# The encode_context signature now also takes the extra context.
35
context_encoding_tensor = module(
36
    {"context": text_placeholder, "extra_context": extra_text_placeholder}, signature="encode_context"
37
)
38

39

40
responce_text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
41

42
response_encoding_tensor = module(responce_text_placeholder, signature="encode_response")
43

44
sess.run(tf.tables_initializer())
45
sess.run(tf.global_variables_initializer())
46

47

48
def encode_contexts(dialog_history_batch):
49
    """Encode the dialog context to the response ranking vector space.
50

51
    Args:
52
        dialog_history: a list of strings, the dialog history, in
53
            chronological order.
54
    """
55

56
    # The context is the most recent message in the history.
57
    contexts = []
58
    extra_context_features = []
59

60
    for dialog_history in dialog_history_batch:
61
        contexts += [dialog_history[-1]]
62

63
        extra_context = list(dialog_history[:-1])
64
        extra_context.reverse()
65
        extra_context_features += [" ".join(extra_context)]
66

67
    return sess.run(
68
        context_encoding_tensor,
69
        feed_dict={text_placeholder: contexts, extra_text_placeholder: extra_context_features},
70
    )
71

72

73
def encode_responses(texts):
74
    return sess.run(response_encoding_tensor, feed_dict={responce_text_placeholder: texts})
75

76

77
def get_convert_score(contexts, responses):
78
    context_encodings = encode_contexts(contexts)
79
    response_encodings = encode_responses(responses)  # 79, 512
80
    res = np.multiply(context_encodings, response_encodings)
81
    return np.sum(res, axis=1).reshape(-1, 1)
82

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.