dream

Форк
0
83 строки · 2.3 Кб
1
# %%
2
import logging
3
import argparse
4
import pickle
5
import pathlib
6
import json
7

8
import tqdm
9
import tensorflow_hub as tfhub
10
import tensorflow as tf
11
import tensorflow_text
12
import numpy as np
13

14
tensorflow_text.__name__
15

16
logging.basicConfig(
17
    level=logging.INFO,
18
    format="%(asctime)s %(module)s %(lineno)d %(levelname)s : %(message)s",
19
    handlers=[
20
        logging.StreamHandler(),
21
        # logging.FileHandler('log.txt'),
22
    ],
23
)
24
logger = logging.getLogger(__name__)
25

26
if globals().get("get_ipython"):
27
    import sys
28

29
    sys.argv = [""]
30
    del sys
31

32
parser = argparse.ArgumentParser()
33
parser.add_argument(
34
    "--responses_file_path",
35
    type=pathlib.Path,
36
    help="Path to the json responses file",
37
    default="score_filtered_comments.json",
38
)
39
parser.add_argument(
40
    "--store_file_path",
41
    type=pathlib.Path,
42
    help="Store to a file of pickle format",
43
    default="replies.pkl",
44
)
45
parser.add_argument("--tfhub_model_dir_path", type=pathlib.Path, help="Path of a tfhub model dir", default="convert")
46
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
47
parser.add_argument("--embedded_key", type=str, default="reply")
48
parser.add_argument("--associative_value", type=str, default="reply")
49

50
args = parser.parse_args()
51

52
responses = [
53
    (req_res.get(args.embedded_key), req_res.get(args.associative_value))
54
    for req_res in json.load(args.responses_file_path.open())
55
    if req_res
56
]
57
responses = [(key, value) for key, value in responses if key and value]
58

59
sess = tf.InteractiveSession(graph=tf.Graph())
60

61
module = tfhub.Module(str(args.tfhub_model_dir_path))
62
text_placeholder = tf.placeholder(dtype=tf.string, shape=[None])
63

64
response_encoding_tensor = module(text_placeholder, signature="encode_response")
65

66
sess.run(tf.tables_initializer())
67
sess.run(tf.global_variables_initializer())
68

69

70
def encode_responses(texts):
71
    return sess.run(response_encoding_tensor, feed_dict={text_placeholder: texts})
72

73

74
keys, values = list(zip(*responses))
75
key_encodings = []
76
for i in tqdm.tqdm(range(0, len(keys), args.batch_size)):
77
    batch = keys[i : i + args.batch_size]
78
    key_encodings.append(encode_responses(batch))
79

80
key_encodings = np.concatenate(key_encodings)
81
logger.info(f"Encoded {key_encodings.shape[0]} candidate responses.")
82

83
pickle.dump((key_encodings, values), args.store_file_path.open("wb"))
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.