create_data_and_train_model.py
#!/usr/bin/env python
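"""Create intent data with USE embeddings and train a linear intent classifier.

Intent phrases (and random phrases) from ``intent_phrases.json`` are embedded
with the Universal Sentence Encoder from TF Hub, a linear classifier is scored
with ``score_model`` (which also yields per-intent thresholds), and the final
model is trained and saved together with the thresholds.
"""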

import os
import json
import argparse
import tensorflow as tf
import tensorflow_hub as hub
from collections import OrderedDict
from utils import generate_phrases, get_linear_classifier, get_train_data, score_model

MODEL_NAME = "linear_classifier"
MULTILABEL = True
TRAIN_SIZE = 0.5
DENSE_LAYERS = 2
MODEL_NAME += "_h" + str(DENSE_LAYERS)
INTENT_DATA_PATH = "./intent_data_h" + str(DENSE_LAYERS) + ".json"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--intent_phrases_path", help="file with phrases for embedding generation", default="intent_phrases.json"
)
parser.add_argument("--model_path", help="path where to save the model", default="./models/" + MODEL_NAME + ".h5")
parser.add_argument("--epochs", help="number of epochs to train model", default=7)
# Whether to calc metrics or not (default value = True)
args = parser.parse_args()
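
# Example invocation (values shown are the defaults defined above):
#   python create_data_and_train_model.py \
#       --intent_phrases_path intent_phrases.json \
#       --model_path ./models/linear_classifier_h2.h5 \
#       --epochs 7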

# Create metrics directory if it does not exist
if not os.path.exists("../metrics/"):
    os.makedirs("../metrics")

USE_MODEL_PATH = os.environ.get("USE_MODEL_PATH", None)
if USE_MODEL_PATH is None:
    USE_MODEL_PATH = "https://tfhub.dev/google/universal-sentence-encoder/1"

TFHUB_CACHE_DIR = os.environ.get("TFHUB_CACHE_DIR", None)
if TFHUB_CACHE_DIR is None:
    os.environ["TFHUB_CACHE_DIR"] = "../tfhub_model"


def main():
    # Load the Universal Sentence Encoder module from TF Hub (TF1-style hub.Module API)
    use = hub.Module(USE_MODEL_PATH)

    with open(args.intent_phrases_path, "r") as fp:
        all_data = json.load(fp)
        intent_phrases = OrderedDict(all_data["intent_phrases"])
        random_phrases = all_data["random_phrases"]

    intent_data = {}
    intents = sorted(list(intent_phrases.keys()))
    print("Creating data...")
    print("Intent: number of original phrases")
    with tf.compat.v1.Session() as sess:
        sess.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])

        # Generate phrase variants for each intent and store them with per-intent metadata
        for intent, data in intent_phrases.items():
            phrases = generate_phrases(data["phrases"], data["punctuation"])
            intent_data[intent] = {
                "generated_phrases": phrases,
                "num_punctuation": len(data["punctuation"]),
                "min_precision": data["min_precision"],
            }
            print(f"{intent}: {len(phrases)//len(data['punctuation'])}")

        # Build USE embedding ops for the intent phrases and the random phrases, then run them in the session
        intent_embeddings_op = {
            intent: use(sentences["generated_phrases"]) for intent, sentences in intent_data.items()
        }

        random_preembedded = generate_phrases(random_phrases["phrases"], random_phrases["punctuation"])
        random_embeddings_op = use(random_preembedded)

        intent_embeddings = sess.run(intent_embeddings_op)
        random_embeddings = sess.run(random_embeddings_op)

        # Replace each intent's generated phrases with their computed embeddings
        for intent in intents:
            intent_data[intent] = {
                "embeddings": intent_embeddings[intent].tolist(),
                "min_precision": intent_data[intent]["min_precision"],
                "num_punctuation": intent_data[intent]["num_punctuation"],
            }

    print("Created!")

    random_embeddings = random_embeddings.tolist()

    print("Scoring model...")

    metrics, thresholds = score_model(
        intent_data,
        intents,
        random_embeddings,
        samples=20,
        dense_layers=DENSE_LAYERS,
        epochs=int(args.epochs),
        train_size=TRAIN_SIZE,
        multilabel=MULTILABEL,
    )

    metrics.to_csv("../metrics/" + MODEL_NAME + "_metrics.csv")
    print("METRICS:")
    print(metrics)

    print("Training model...")
    train_data = get_train_data(intent_data, intents, random_embeddings, multilabel=MULTILABEL)
    model = get_linear_classifier(intents, dense_layers=DENSE_LAYERS, use_metrics=False, multilabel=MULTILABEL)

    model.fit(x=train_data["X"], y=train_data["y"], epochs=int(args.epochs))
    print(f"Saving model to: {args.model_path}")
    model.save(args.model_path)
    print(f"Saving thresholds to: {INTENT_DATA_PATH}")
    with open(INTENT_DATA_PATH, "w") as fp:
        json.dump(thresholds, fp)


if __name__ == "__main__":
    main()