# dream — 217 lines · 8.6 KB
1#!/usr/bin/env python
2
3import tensorflow as tf4import numpy as np5import random6import pandas as pd7from itertools import chain8from tqdm import tqdm9from xeger import Xeger10from sklearn.metrics import precision_recall_curve11
12
def tb_accuracy(y_true, y_pred):
    """Categorical accuracy metric for training logs.

    Collapses one-hot / probability rows to class ids via argmax over the
    class axis, then applies ``tf.keras.metrics.Accuracy``.

    Args:
        y_true: one-hot ground-truth tensor, shape (batch, num_classes).
        y_pred: predicted class probabilities, same shape.

    Returns:
        Scalar accuracy tensor.
    """
    # `dimension=` was the deprecated TF1 keyword and is removed in TF2;
    # `axis=` is the supported spelling of the same argument.
    y_true = tf.math.argmax(y_true, axis=1)
    y_pred = tf.math.argmax(y_pred, axis=1)
    return tf.keras.metrics.Accuracy()(y_true, y_pred)
18
def tb_f1(y_true, y_pred):
    """F1 score built from fresh Keras Precision/Recall metric instances.

    ``epsilon()`` keeps the ratio defined when precision + recall == 0.
    """
    p = tf.keras.metrics.Precision()(y_true, y_pred)
    r = tf.keras.metrics.Recall()(y_true, y_pred)
    denominator = p + r + tf.keras.backend.epsilon()
    return 2 * ((p * r) / denominator)
24
def multilabel_precision(y_true, y_pred):
    """Macro-precision, with per-label thresholds chosen by argmax F1.

    For each label column, the F1-optimal point on the precision-recall
    curve is located and its precision collected; the mean over labels is
    returned.

    Args:
        y_true: binary indicator matrix, shape (n_samples, n_labels).
        y_pred: predicted scores, same shape.

    Returns:
        float: mean of the per-label precisions at the F1-optimal points.
    """
    eps = np.finfo(np.float64).eps
    values = []
    # `.shape` works for NumPy arrays and TF tensors alike (the TF-only
    # `get_shape()` did not), and matches multilabel_f1's convention.
    for i in range(y_true.shape[1]):
        pr, rec, _ = precision_recall_curve(y_true[:, i], y_pred[:, i])
        # eps guards against 0/0 -> NaN, which would otherwise poison argmax.
        f1 = 2.0 * pr * rec / (pr + rec + eps)
        values.append(pr[np.argmax(f1)])
    return np.mean(values)
36
def multilabel_recall(y_true, y_pred):
    """Macro-recall, with per-label thresholds chosen by argmax F1.

    For each label column, the F1-optimal point on the precision-recall
    curve is located and its recall collected; the mean over labels is
    returned.

    Args:
        y_true: binary indicator matrix, shape (n_samples, n_labels).
        y_pred: predicted scores, same shape.

    Returns:
        float: mean of the per-label recalls at the F1-optimal points.
    """
    eps = np.finfo(np.float64).eps
    values = []
    # `.shape` works for NumPy arrays and TF tensors alike (the TF-only
    # `get_shape()` did not), and matches multilabel_f1's convention.
    for i in range(y_true.shape[1]):
        pr, rec, _ = precision_recall_curve(y_true[:, i], y_pred[:, i])
        # eps guards against 0/0 -> NaN, which would otherwise poison argmax.
        f1 = 2.0 * pr * rec / (pr + rec + eps)
        values.append(rec[np.argmax(f1)])
    return np.mean(values)
48
def multilabel_f1(y_true, y_pred):
    """Macro-F1, with per-label thresholds chosen by argmax F1.

    For each label column, the maximal F1 along the precision-recall curve
    is collected; the mean over labels is returned.

    Args:
        y_true: binary indicator matrix, shape (n_samples, n_labels).
        y_pred: predicted scores, same shape.

    Returns:
        float: mean of the per-label maximal F1 values.
    """
    eps = np.finfo(np.float64).eps
    values = []
    for i in range(y_true.shape[1]):
        pr, rec, _ = precision_recall_curve(y_true[:, i], y_pred[:, i])
        # eps guards against 0/0 -> NaN at points where pr == rec == 0;
        # NaN would propagate through np.max and corrupt the mean.
        f1 = 2.0 * pr * rec / (pr + rec + eps)
        values.append(np.max(f1))
    return np.mean(values)
60
def calculate_metrics(intents_min_pr, y_true, y_pred):
    """Pick an operating point per intent from its precision-recall curve.

    For intent column i, the chosen point maximises F1 among curve points
    whose precision exceeds ``intents_min_pr[intent]``; when no point clears
    the floor, the unconstrained F1 maximum is used as a fallback.

    Args:
        intents_min_pr: dict mapping intent name -> minimal precision.
            NOTE(review): its iteration order is assumed to match the column
            order of y_true / y_pred — confirm against callers.
        y_true: binary indicator matrix, shape (n_samples, n_intents).
        y_pred: predicted scores, same shape.

    Returns:
        dict: intent -> {"threshold", "precision", "recall", "f1"} at the
        selected point.
    """
    eps = np.finfo(np.float64).eps
    intent_data = {}
    for i, intent in enumerate(intents_min_pr):
        pr, rec, thresholds = precision_recall_curve(y_true[:, i], y_pred[:, i])
        # eps guards against 0/0 -> NaN at points where pr == rec == 0.
        f1 = 2.0 * pr * rec / (pr + rec + eps)
        # Candidate points whose precision clears the per-intent floor.
        indx = np.argwhere(pr > intents_min_pr[intent]).reshape(-1)
        if indx.size == 0:
            # No point meets the floor: fall back to all points rather than
            # crashing on argmax over an empty slice.
            indx = np.arange(len(pr))
        # Argmax F1(threshold) where precision is greater than the floor.
        indx = indx[np.argmax(f1[indx])]
        # precision_recall_curve returns len(thresholds) == len(pr) - 1, so
        # clamp to avoid an IndexError at the terminal (pr=1, rec=0) point.
        t_indx = min(indx, len(thresholds) - 1)
        intent_data[intent] = {
            "threshold": thresholds[t_indx],
            "precision": pr[indx],
            "recall": rec[indx],
            "f1": f1[indx],
        }
    return intent_data
77
def generate_phrases(template_re, punctuation, limit=2500):
    """Expand regex templates into phrases plus punctuation variants.

    Each template is sampled ``limit`` times with Xeger and deduplicated;
    the result is the bare phrases followed by one copy per punctuation
    mark appended. A failing template is printed before re-raising.
    """
    sampler = Xeger(limit=limit)
    phrases = []
    for regex in template_re:
        try:
            unique_samples = {sampler.xeger(regex) for _ in range(limit)}
        except Exception as e:
            print(e)
            print(regex)
            raise e
        phrases += list(unique_samples)
    variants = [phrases]
    for punct in punctuation:
        variants.append([phrase + punct for phrase in phrases])
    return list(chain.from_iterable(variants))
91
def get_linear_classifier(intents, input_dim=512, dense_layers=1, use_metrics=True, multilabel=False):
    """Build and compile a small dense classifier over phrase embeddings.

    Multilabel mode: one sigmoid unit per intent, binary cross-entropy.
    Single-label mode: softmax over len(intents) + 1 classes (the extra
    class is "no intent"), categorical cross-entropy.
    """
    if multilabel:
        units = len(intents)
        activation = "sigmoid"
        metrics = ["binary_crossentropy"] if use_metrics else []
    else:
        units = len(intents) + 1  # +1 for the "no intent" class
        activation = "softmax"
        if use_metrics:
            metrics = [tb_accuracy, tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), tb_f1]
        else:
            metrics = []

    # Hidden dense layers; only the first one carries input_dim.
    layers = []
    for i in range(dense_layers):
        layers.append(tf.keras.layers.Dense(units=256, activation="relu", input_dim=input_dim if i == 0 else 256))
    # Output layer (takes input_dim directly when there are no hidden layers).
    layers.append(
        tf.keras.layers.Dense(units=units, activation=activation, input_dim=input_dim if not layers else 256)
    )

    model = tf.keras.Sequential(layers)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="binary_crossentropy" if multilabel else "categorical_crossentropy",
        metrics=metrics,
    )
    return model
118
def train_test_split(full_length, punct_num, train_size):
    """Split phrase indices into train/test, keeping punctuation variants together.

    The embedding list is assumed to hold the base phrases first, followed by
    punct_num blocks of their punctuated copies in the same order, so each
    base phrase and all of its variants land in the same split.

    Args:
        full_length: total number of embeddings (base + all variants).
        punct_num: number of punctuation variants per base phrase.
        train_size: fraction of base phrases assigned to train.

    Returns:
        (train_idx, test_idx): disjoint index lists covering range(full_length).
    """
    # Number of original (un-punctuated) phrases.
    original_length = full_length // (punct_num + 1)
    train_length = int(original_length * train_size)

    # Sample base-phrase indices for train; the remainder go to test.
    train_idx = random.sample(list(range(original_length)), train_length)
    test_idx = list(set(range(original_length)) - set(train_idx))

    def expand(indices):
        # Add each punctuation block's offset to every base index.
        expanded = []
        for p in range(punct_num + 1):
            offset = original_length * p
            expanded.extend(i + offset for i in indices)
        return expanded

    return expand(train_idx), expand(test_idx)
139
def get_train_test_data(data, intents, random_phrases_embeddings, multilabel=False, train_size=0.8):
    """Assemble train/test matrices with one-hot labels per intent.

    Random ("no intent") phrases are appended to the training set only and
    labelled with class index len(intents); in multilabel mode that index is
    out of range, so their label rows are all zeros.
    """
    num_classes = len(intents) if multilabel else len(intents) + 1

    def label_rows(class_idx, count):
        # `count` copies of the one-hot (or all-zero) row for `class_idx`.
        row = [1.0 if j == class_idx else 0.0 for j in range(num_classes)]
        return [list(row) for _ in range(count)]

    train_data = {"X": [], "y": []}
    test_data = {"X": [], "y": []}
    for i, intent in enumerate(intents):
        embeddings = np.array(data[intent]["embeddings"])
        train_idx, test_idx = train_test_split(
            len(embeddings), data[intent]["num_punctuation"], train_size=train_size
        )
        train_data["X"].append(embeddings[train_idx])
        train_data["y"].append(label_rows(i, len(train_idx)))
        test_data["X"].append(embeddings[test_idx])
        test_data["y"].append(label_rows(i, len(test_idx)))

    train_data["X"].append(random_phrases_embeddings)
    train_data["y"].append(label_rows(len(intents), len(random_phrases_embeddings)))

    for split in (train_data, test_data):
        split["X"] = np.concatenate(split["X"])
        split["y"] = np.concatenate(split["y"])
    return train_data, test_data
166
def get_train_data(data, intents, random_phrases_embeddings, multilabel=False):
    """Assemble the full training matrix with one-hot labels (no test split).

    Random ("no intent") phrases get class index len(intents); in multilabel
    mode that index is out of range, so their label rows are all zeros.
    """
    num_classes = len(intents) if multilabel else len(intents) + 1

    def label_rows(class_idx, count):
        # `count` copies of the one-hot (or all-zero) row for `class_idx`.
        row = [1.0 if j == class_idx else 0.0 for j in range(num_classes)]
        return [list(row) for _ in range(count)]

    features = []
    labels = []
    for i, intent in enumerate(intents):
        embeddings = np.array(data[intent]["embeddings"])
        features.append(embeddings)
        labels.append(label_rows(i, len(embeddings)))

    features.append(random_phrases_embeddings)
    labels.append(label_rows(len(intents), len(random_phrases_embeddings)))

    return {"X": np.concatenate(features), "y": np.concatenate(labels)}
184
def score_model(
    data, intents, random_phrases_embeddings, samples=20, dense_layers=1, train_size=0.5, epochs=80, multilabel=False
):
    """Estimate classifier quality by repeated train/evaluate runs.

    Trains `samples` fresh classifiers on fresh random splits, collects the
    per-intent operating point from calculate_metrics on each test split,
    prints mean±std per intent, and returns the averaged metrics.

    Args:
        data: dict intent -> {"embeddings", "num_punctuation", "min_precision", ...}.
        intents: ordered intent names; must match the label column order.
        random_phrases_embeddings: embeddings of "no intent" phrases.
        samples: number of independent train/evaluate repetitions.
        dense_layers: hidden layers for get_linear_classifier.
        train_size: train fraction passed to the splitter.
        epochs: training epochs per run.
        multilabel: forwarded to model/data construction.

    Returns:
        (metrics, thresholds): a pandas DataFrame of mean metrics keyed by
        intent, and a dict intent -> mean decision threshold (float).
    """
    # Per-intent accumulators, one list entry per sampled run.
    metrics = {intent: {"precision": [], "recall": [], "f1": [], "threshold": []} for intent in intents}
    intents_min_pr = {intent: v["min_precision"] for intent, v in data.items()}
    for _ in tqdm(range(samples)):
        # Fresh model and fresh random split each repetition.
        model = get_linear_classifier(intents=intents, dense_layers=dense_layers, multilabel=multilabel)
        train_data, test_data = get_train_test_data(
            data, intents, random_phrases_embeddings, multilabel=multilabel, train_size=train_size
        )
        model.fit(x=train_data["X"], y=train_data["y"], epochs=epochs, verbose=0)

        # Operating point per intent on the held-out split.
        current_metrics = calculate_metrics(intents_min_pr, test_data["y"], model.predict(test_data["X"]))
        for intent in current_metrics:
            for metric_name in current_metrics[intent]:
                metrics[intent][metric_name].append(current_metrics[intent][metric_name])
    # Human-readable mean±std report per intent.
    for intent in intents:
        precision = (np.mean(metrics[intent]["precision"]), np.std(metrics[intent]["precision"]))
        recall = (np.mean(metrics[intent]["recall"]), np.std(metrics[intent]["recall"]))
        f1 = (np.mean(metrics[intent]["f1"]), np.std(metrics[intent]["f1"]))
        threshold = (np.mean(metrics[intent]["threshold"]), np.std(metrics[intent]["threshold"]))
        message = (
            f"\nIntent: {intent}\n"
            + f"PRECISION: {precision[0]}±{precision[1]}\n"
            + f"RECALL: {recall[0]}±{recall[1]}\n"
            + f"F1: {f1[0]}±{f1[1]}\n"
            + f"Threshold: {threshold[0]}±{threshold[1]}\n\n"
        )
        print(message)
    # Reduce each metric list to its mean across runs.
    metrics = {intent: {metric: np.mean(metrics[intent][metric]) for metric in metrics[intent]} for intent in metrics}
    # NOTE(review): "threshold" is already a scalar here, so this extra
    # np.mean is a no-op kept for byte-compatibility.
    thresholds = {intent: float(np.mean(metrics[intent]["threshold"])) for intent in metrics}
    metrics = pd.DataFrame.from_dict(metrics)
    return metrics, thresholds