dream
63 строки · 1.6 Кб
1#!/usr/bin/env python
2
3import os4import json5import tensorflow as tf6import tensorflow_hub as hub7import numpy as np8from utils import cosine_similarity_debug9
10
# JSON file with per-intent data; main() reads json[INTENT] and uses its
# "phrases", "threshold" and "embeddings" entries. The path is resolved
# relative to the current working directory, not to this script.
INTENT_DATA_PATH = "./data/intent_data.json"

# Sentence-encoder module handle/path; overridable via USE_MODEL_PATH.
MODEL_PATH = os.environ.get(
    "USE_MODEL_PATH", "https://tfhub.dev/google/universal-sentence-encoder/1"
)

# TF Hub download cache directory. An existing TFHUB_CACHE_DIR environment
# variable is respected; otherwise it is set to "../tfhub_model" (relative
# to the working directory). setdefault keeps this constant equal to the
# value actually placed in the process environment.
TFHUB_CACHE_DIR = os.environ.setdefault("TFHUB_CACHE_DIR", "../tfhub_model")

# Probe utterances that main() embeds and matches against the intent.
PHRASES = [
    "Okay",
    "Okay, Alexa",
    "Bye, Alexa",
    "Alexa, bye",
    "Goodbye, Alexa",
    "Goodbye, bot",
    "Bot, goodbye",
    "Bye, bot",
    "Have a nice one",
    "Hello",
    "Hi",
    "Hello, bot",
    "Hello, Alexa",
    "Hi, bot",
    "Hi, Alexa",
    "Hey, Alexa",
    "Okay, have a good day!",
    "Have a good day, Alexa",
    "Okay, Alexa, have a good day",
]
# Key of the intent entry in INTENT_DATA_PATH to compare against.
INTENT = "exit"
44
def main():
    """Print how each phrase in PHRASES matches the stored INTENT data.

    Loads the encoder module from MODEL_PATH, embeds every entry of PHRASES,
    and passes those embeddings together with the precomputed embeddings
    stored under INTENT in INTENT_DATA_PATH to ``cosine_similarity_debug``.
    For every phrase, one tuple is printed:
    ``(phrase, matched intent phrase, value, intent threshold)``.
    """
    model = hub.Module(MODEL_PATH)

    # Use a context manager so the JSON file handle is closed deterministically.
    with open(INTENT_DATA_PATH, encoding="utf-8") as intent_file:
        intent_data = json.load(intent_file)[INTENT]

    embedded_phrases = model(PHRASES)
    intent_phrases = np.array(intent_data["phrases"])
    threshold = intent_data["threshold"]
    intent_embeddings = tf.constant(intent_data["embeddings"], dtype=tf.float32)
    # NOTE(review): presumably yields (per-phrase similarity, per-phrase index
    # into intent_phrases); cosine_similarity_debug lives in utils — confirm.
    sim = cosine_similarity_debug(embedded_phrases, intent_embeddings)

    with tf.compat.v1.Session() as sess:
        # Initialise graph variables and lookup tables before evaluating.
        sess.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])
        values, similarity_ids = sess.run(sim)

    for phrase, matched_phrase, value in zip(PHRASES, intent_phrases[similarity_ids], values):
        print((phrase, matched_phrase, value, threshold))
61
# Script entry point: run the phrase/intent comparison only when this file
# is executed directly, never when it is imported.
if __name__ == "__main__":
    main()