# dream · 114 lines · 3.8 KB  (repository file-listing header left over from extraction; not part of the script)
1#!/usr/bin/env python
2
3import os
4import json
5import argparse
6import tensorflow as tf
7import tensorflow_hub as hub
8from collections import OrderedDict
9from utils import *
10
# ---- Model / training configuration --------------------------------------
MODEL_NAME = "linear_classifier"
MULTILABEL = True   # a phrase may match several intents at once
TRAIN_SIZE = 0.5    # train fraction used by score_model
DENSE_LAYERS = 2    # hidden dense layers in the classifier head
MODEL_NAME += "_h" + str(DENSE_LAYERS)
INTENT_DATA_PATH = "./intent_data_h" + str(DENSE_LAYERS) + ".json"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--intent_phrases_path", help="file with phrases for embedding generation", default="intent_phrases.json"
)
parser.add_argument("--model_path", help="path where to save the model", default="./models/" + MODEL_NAME + ".h5")
# FIX: coerce --epochs to int at parse time; previously a CLI-supplied value
# arrived as str and every consumer had to remember to call int() on it.
parser.add_argument("--epochs", help="number of epochs to train model", type=int, default=7)
args = parser.parse_args()

# Create the metrics output directory; exist_ok avoids the
# check-then-create race of the former exists()/makedirs() pair.
os.makedirs("../metrics", exist_ok=True)

# Embedding model: overridable via env var, defaults to the public TFHub USE.
USE_MODEL_PATH = os.environ.get("USE_MODEL_PATH")
if USE_MODEL_PATH is None:
    USE_MODEL_PATH = "https://tfhub.dev/google/universal-sentence-encoder/1"

# Cache downloaded TFHub modules locally unless the caller already chose a dir.
os.environ.setdefault("TFHUB_CACHE_DIR", "../tfhub_model")
39
def main():
    """Embed intent phrases with the Universal Sentence Encoder, score and
    train a linear intent classifier, then persist the artifacts.

    Side effects: reads args.intent_phrases_path (JSON), writes a metrics CSV
    to ../metrics/, the trained Keras model to args.model_path, and the
    per-intent thresholds to INTENT_DATA_PATH.
    """
    # TF1-style hub module: calling it builds graph ops that must be
    # evaluated inside a Session (see sess.run below).
    use = hub.Module(USE_MODEL_PATH)

    with open(args.intent_phrases_path, "r") as fp:
        all_data = json.load(fp)
    # OrderedDict keeps a stable iteration order for embedding-op creation.
    intent_phrases = OrderedDict(all_data["intent_phrases"])
    random_phrases = all_data["random_phrases"]

    intent_data = {}
    intents = sorted(intent_phrases.keys())
    print("Creating data...")
    print("Intent: number of original phrases")

    with tf.compat.v1.Session() as sess:
        sess.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])

        for intent, data in intent_phrases.items():
            # generate_phrases expands each template with every punctuation
            # variant, so len(phrases) == originals * len(punctuation).
            phrases = generate_phrases(data["phrases"], data["punctuation"])
            intent_data[intent] = {
                "generated_phrases": phrases,
                "num_punctuation": len(data["punctuation"]),
                "min_precision": data["min_precision"],
            }
            print(f"{intent}: {len(phrases)//len(data['punctuation'])}")

        # Build all embedding ops first, then evaluate them in one place.
        intent_embeddings_op = {
            intent: use(sentences["generated_phrases"]) for intent, sentences in intent_data.items()
        }
        random_preembedded = generate_phrases(random_phrases["phrases"], random_phrases["punctuation"])
        random_embeddings_op = use(random_preembedded)

        # sess.run materializes the ops into plain numpy arrays, so the
        # session is no longer needed after this point.
        intent_embeddings = sess.run(intent_embeddings_op)
        random_embeddings = sess.run(random_embeddings_op)

    # Replace raw phrases with their embeddings; keep only what the
    # scoring/training helpers consume.
    for intent in intents:
        intent_data[intent] = {
            "embeddings": intent_embeddings[intent].tolist(),
            "min_precision": intent_data[intent]["min_precision"],
            "num_punctuation": intent_data[intent]["num_punctuation"],
        }

    print("Created!")

    random_embeddings = random_embeddings.tolist()

    print("Scoring model...")
    metrics, thresholds = score_model(
        intent_data,
        intents,
        random_embeddings,
        samples=20,
        dense_layers=DENSE_LAYERS,
        epochs=int(args.epochs),
        train_size=TRAIN_SIZE,
        multilabel=MULTILABEL,
    )

    metrics.to_csv("../metrics/" + MODEL_NAME + "_metrics.csv")
    print("METRICS:")
    print(metrics)

    print("Training model...")
    train_data = get_train_data(intent_data, intents, random_embeddings, multilabel=MULTILABEL)
    model = get_linear_classifier(intents, dense_layers=DENSE_LAYERS, use_metrics=False, multilabel=MULTILABEL)

    model.fit(x=train_data["X"], y=train_data["y"], epochs=int(args.epochs))
    print(f"Saving model to: {args.model_path}")
    model.save(args.model_path)
    print(f"Saving thresholds to: {INTENT_DATA_PATH}")
    # FIX: the original json.dump(thresholds, open(..., "w")) leaked the file
    # handle; a with-block guarantees flush + close.
    with open(INTENT_DATA_PATH, "w") as fp:
        json.dump(thresholds, fp)
111
112
# Script entry point: run the training pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
115