dream
83 строки · 2.3 Кб
1# %%
2import logging
3import argparse
4import pickle
5import pathlib
6import json
7
8import tqdm
9import tensorflow_hub as tfhub
10import tensorflow as tf
11import tensorflow_text
12import numpy as np
13
# Touch the module once so the `tensorflow_text` import is not treated as
# unused. NOTE(review): presumably the import is wanted for its side effects
# (e.g. registering ops the hub model relies on) — confirm.
tensorflow_text.__name__

# Root logging setup: INFO and above, one stream handler. Append a
# logging.FileHandler to `handlers` to also write the log to a file.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s %(module)s %(lineno)d %(levelname)s : %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
25
# Interactive runs (the file starts with a "# %%" cell marker): when a
# `get_ipython` name is present in the module globals, replace the process
# argv with a single empty program name so that argparse sees no options and
# uses its defaults. NOTE(review): presumably needed because the kernel's own
# argv would otherwise be parsed as script options — confirm.
if globals().get("get_ipython"):
    import sys

    sys.argv = [""]
    # Remove the temporary name again so `sys` does not linger in the namespace.
    del sys
31
# Command-line interface. Every option has a default, so the script also runs
# with an empty argv. argparse passes string defaults through `type`, so the
# two path options end up as pathlib.Path objects either way.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--responses_file_path",
    type=pathlib.Path,
    help="Path to the json responses file",
    default="score_filtered_comments.json",
)
parser.add_argument(
    "--store_file_path",
    type=pathlib.Path,
    help="Store to a file of pickle format",
    default="replies.pkl",
)
parser.add_argument(
    "--tfhub_model_dir_path",
    type=pathlib.Path,
    help="Path of a tfhub model dir",
    default="convert",
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument(
    "--embedded_key",
    type=str,
    default="reply",
    help="Field of each json record whose text is encoded",
)
parser.add_argument(
    "--associative_value",
    type=str,
    default="reply",
    help="Field of each json record stored alongside the encodings",
)

args = parser.parse_args()

# The encoding loop further down steps through the inputs with
# range(0, n, batch_size); zero or negative sizes cannot work there, so reject
# them here with a normal usage error instead of an obscure failure later.
if args.batch_size < 1:
    parser.error("--batch_size must be a positive integer")
51
# Read the JSON records and turn each non-empty one into a
# (text to embed, value to store) pair. Records are accessed with `.get`, so
# they are presumably dicts; a missing field yields None.
# The file is opened in a `with` block so the handle is closed deterministically,
# and decoded as UTF-8 explicitly rather than with the platform default encoding.
with args.responses_file_path.open(encoding="utf-8") as responses_file:
    responses = [
        (req_res.get(args.embedded_key), req_res.get(args.associative_value))
        for req_res in json.load(responses_file)
        if req_res
    ]

# Keep only pairs where both sides are present and non-empty.
responses = [(key, value) for key, value in responses if key and value]
58
# Graph-mode (TF1-style) setup. Statement order matters: the session and its
# fresh graph come first, then the hub module and placeholder are created,
# then the initializers are run.
sess = tf.InteractiveSession(graph=tf.Graph())

# Load the hub module from the directory given on the command line.
module = tfhub.Module(str(args.tfhub_model_dir_path))

# Input: a 1-D batch of strings of any length.
text_placeholder = tf.placeholder(tf.string, shape=[None])

# Output tensor of the module's "encode_response" signature for that input.
response_encoding_tensor = module(text_placeholder, signature="encode_response")

# Run the table and global-variable initializers once, before any encoding.
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
68
69
def encode_responses(texts):
    """Evaluate the module's response encoding for one batch of texts.

    Args:
        texts: Sequence of strings fed to ``text_placeholder``.

    Returns:
        The value ``sess.run`` produces for ``response_encoding_tensor``.
        The caller concatenates these results with ``np.concatenate`` and
        reads ``shape[0]``, so it is presumably an array with one row per
        input text — TODO confirm.
    """
    feed = {text_placeholder: texts}
    return sess.run(response_encoding_tensor, feed_dict=feed)
72
73
# With no usable records both the unpacking below and np.concatenate would
# fail with unhelpful messages, so stop early with an explicit ValueError.
if not responses:
    raise ValueError(
        f"No usable records in {args.responses_file_path}: nothing has both "
        f"{args.embedded_key!r} and {args.associative_value!r} set."
    )

# Split the pairs into two parallel tuples: texts to encode, values to store.
keys, values = zip(*responses)

# Encode the texts batch by batch; each call returns the encodings of one batch.
key_encodings = []
for i in tqdm.tqdm(range(0, len(keys), args.batch_size)):
    batch = keys[i : i + args.batch_size]
    key_encodings.append(encode_responses(batch))

# Join the per-batch results along the first axis.
key_encodings = np.concatenate(key_encodings)
logger.info("Encoded %d candidate responses.", key_encodings.shape[0])

# Store the encodings together with their associated values; the `with`
# block guarantees the file is flushed and closed.
with args.store_file_path.open("wb") as store_file:
    pickle.dump((key_encodings, values), store_file)
84