dream
466 строк · 20.3 Кб
1import os2
3import helper4import numpy as np5import tensorflow as tf6import tensorflow_hub as hub7from nltk import word_tokenize8from tensorflow.contrib.layers import xavier_initializer, xavier_initializer_conv2d9
10
class model:
    """BiLSTM-CRF sequence labeller over word + char-CNN + ELMo features (TF1.x)."""

    def __init__(self, params, pretrained_model_path=""):
        """Load vocabularies and build the TensorFlow 1.x computation graph.

        params: configuration object; fields read here: dicts_file, word_dim,
            train, pretrained_emb, char_dim, nb_filters_1, nb_filters_2,
            elmo_dim, word_hidden_size.
        pretrained_model_path: checkpoint directory to restore from later; when
            non-empty, pretrained word embeddings are NOT loaded here (the
            checkpoint already contains the trained embedding matrix).
        """
        self.params = params
        self.pretrained_model_path = pretrained_model_path
        # Lookup tables produced by the preprocessing pipeline.
        dicts = helper.load_dictionaries(self.params.dicts_file)
        self.word2id = dicts["word2id"]
        self.id2word = dicts["id2word"]
        self.char2id = dicts["char2id"]
        self.id2char = dicts["id2char"]
        self.tag2id = dicts["tag2id"]
        self.id2tag = dicts["id2tag"]

        # Word-embedding matrix: zeros unless pretrained vectors are loaded below.
        self.pretrained_emb = np.zeros(shape=(len(self.word2id), self.params.word_dim))
        if self.pretrained_model_path == "" and self.params.train != "" and self.params.pretrained_emb != "":
            # NOTE(review): the zero matrix is passed as the second argument —
            # presumably helper.load_word_emb fills/overlays it; confirm against
            # helper's signature.
            self.pretrained_emb = helper.load_word_emb(self.word2id, self.pretrained_emb, self.params.pretrained_emb)

        # build model
        # --- graph inputs ---------------------------------------------------
        self.tf_word_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="word_ids")
        self.tf_sentence_lengths = tf.placeholder(dtype=tf.int32, shape=[None], name="sentence_lengths")
        self.tf_labels = tf.placeholder(dtype=tf.int32, shape=[None, None], name="labels")
        # Fed 1.0 at eval time and params.dropout at train time, i.e. this is a
        # keep probability (TF1 tf.nn.dropout semantics).
        self.tf_dropout = tf.placeholder(dtype=tf.float32, shape=[], name="drop_out")
        self.tf_learning_rate = tf.placeholder(dtype=tf.float32, shape=[], name="learning_rate")
        self.tf_char_ids = tf.placeholder(dtype=tf.int32, shape=[None, None, None], name="char_ids")
        self.tf_word_lengths = tf.placeholder(dtype=tf.int32, shape=[None, None], name="word_lengths")
        # Raw token strings consumed by the ELMo hub module.
        self.tf_raw_word = tf.placeholder(dtype=tf.string, shape=[None, None], name="raw_word")

        # --- trainable word embeddings --------------------------------------
        with tf.variable_scope("word_embedding"):
            tf_word_embeddings = tf.Variable(
                self.pretrained_emb, dtype=tf.float32, trainable=True, name="word_embedding"
            )
            self.input = tf.nn.embedding_lookup(tf_word_embeddings, self.tf_word_ids, name="embedded_words")

        # --- character-level CNN features -----------------------------------
        with tf.variable_scope("char_cnn"):
            tf_char_embeddings = tf.get_variable(
                name="char_embeddings",
                dtype=tf.float32,
                shape=[len(self.char2id), self.params.char_dim],
                trainable=True,
                initializer=xavier_initializer(),
            )
            embedded_cnn_chars = tf.nn.embedding_lookup(tf_char_embeddings, self.tf_char_ids, name="embedded_cnn_chars")
            # Two stacked width-3 convolutions along the character axis.
            conv1 = tf.layers.conv2d(
                inputs=embedded_cnn_chars,
                filters=self.params.nb_filters_1,
                kernel_size=(1, 3),
                strides=(1, 1),
                padding="same",
                name="conv1",
                kernel_initializer=xavier_initializer_conv2d(),
            )
            conv2 = tf.layers.conv2d(
                inputs=conv1,
                filters=self.params.nb_filters_2,
                kernel_size=(1, 3),
                strides=(1, 1),
                padding="same",
                name="conv2",
                kernel_initializer=xavier_initializer_conv2d(),
            )
            # Max over the character dimension -> one vector per word.
            char_cnn = tf.reduce_max(conv2, axis=2)
            self.input = tf.concat([self.input, char_cnn], axis=-1)

        # --- frozen ELMo embeddings, projected to elmo_dim ------------------
        with tf.variable_scope("elmo_emb"):
            elmo = hub.Module("/elmo2", trainable=False)
            embeddings = elmo(
                inputs={"tokens": self.tf_raw_word, "sequence_len": self.tf_sentence_lengths},
                signature="tokens",
                as_dict=True,
            )[
                "elmo"
            ]  # num_sent, max_sent_len, 1024
            elmo_emb = tf.layers.dense(inputs=embeddings, units=self.params.elmo_dim, activation=None)
            self.input = tf.concat([self.input, elmo_emb], axis=-1)

        self.input = tf.nn.dropout(self.input, self.tf_dropout)

        # --- BiLSTM encoder + per-token projection to tag scores ------------
        with tf.variable_scope("bi_lstm_words"):
            cell_fw = tf.contrib.rnn.LSTMCell(self.params.word_hidden_size)
            cell_bw = tf.contrib.rnn.LSTMCell(self.params.word_hidden_size)
            (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, self.input, sequence_length=self.tf_sentence_lengths, dtype=tf.float32
            )
            self.output = tf.concat([output_fw, output_bw], axis=-1)
            ntime_steps = tf.shape(self.output)[1]
            # Flatten (batch, time, 2H) -> (batch*time, 2H) for the dense layers.
            self.output = tf.reshape(self.output, [-1, 2 * params.word_hidden_size])
            layer1 = tf.nn.dropout(
                tf.layers.dense(
                    inputs=self.output,
                    units=params.word_hidden_size,
                    activation=None,
                    kernel_initializer=xavier_initializer(),
                ),
                self.tf_dropout,
            )
            pred = tf.layers.dense(
                inputs=layer1, units=len(self.tag2id), activation=None, kernel_initializer=xavier_initializer()
            )
            # Back to (batch, time, n_tags) for the CRF layer.
            self.logits = tf.reshape(pred, [-1, ntime_steps, len(self.tag2id)])

        # compute loss value using crf
        log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
            self.logits, self.tf_labels, self.tf_sentence_lengths
        )
        with tf.variable_scope("loss_and_opt"):
            self.tf_loss = tf.reduce_mean(-log_likelihood)
            optimizer = tf.train.AdamOptimizer(learning_rate=self.tf_learning_rate)
            self.tf_train_op = optimizer.minimize(self.tf_loss)
119def read_raw_data(self, raw_file_path, min_length_of_sentence):120# return raw_data{word, tag, pos, chunk}121word, word_, tag, tag_ = [], [], [], []122nb_part = 2123lines = open(file=raw_file_path, mode="r", encoding="utf8").readlines()124for line in lines:125if line.startswith("-DOCSTART-"):126continue127tokens = line.strip().split()128# end of the sentence129if len(tokens) == 0:130if len(word_) < min_length_of_sentence:131continue132word.append(word_)133tag.append(tag_)134word_, tag_ = [], []135continue136if len(tokens) < nb_part:137print("* input data is not valid:", line)138continue139word_.append(tokens[0])140tag_.append(tokens[-1])141raw_data = {"word": word, "tag": tag}142return raw_data143
144def index_data(self, raw_data):145# input: raw_data{word, tag}146# output: indexed_data{indexed_word, indexed_char, indexed_tag}147def low(x):148return x.lower() if self.params.lower == 1 else x149
150def zer(s):151return helper.zeros(s) if self.params.zeros == 1 else s152
153word = [[low(zer(x)) for x in s] for s in raw_data["word"]]154indexed_word = [[self.word2id[w] if w in self.word2id else self.word2id["<UNK>"] for w in s] for s in word]155indexed_data = {"indexed_word": indexed_word, "raw_word": raw_data["word"]}156if "tag" in raw_data:157indexed_tag = [[self.tag2id[t] for t in s] for s in raw_data["tag"]]158indexed_data["indexed_tag"] = indexed_tag159indexed_char = [160[[self.char2id[c] if c in self.char2id else self.char2id["<UNK>"] for c in zer(w)] for w in s]161for s in raw_data["word"]162]163indexed_data["indexed_char"] = indexed_char164return indexed_data165
166def get_batch(self, data, start_idx):167# input: data{indexed_word, indexed_char, indexed_tag, indexed_pos, indexed_chunk}168# output: a batch of data after padding169nb_sentences = len(data["indexed_word"])170end_idx = start_idx + self.params.batch_size171if end_idx > nb_sentences:172end_idx = nb_sentences173batch_word = data["indexed_word"][start_idx:end_idx]174if "indexed_tag" in data:175batch_tag = data["indexed_tag"][start_idx:end_idx]176batch_char = data["indexed_char"][start_idx:end_idx]177batch_raw_word = data["raw_word"][start_idx:end_idx]178real_sentence_lengths = [len(sent) for sent in batch_word]179max_len_sentences = max(real_sentence_lengths)180
181padded_word = [182np.lib.pad(183sent,184(0, max_len_sentences - len(sent)),185"constant",186constant_values=(self.word2id["<PAD>"], self.word2id["<PAD>"]),187)188for sent in batch_word189]190
191batch = {192"batch_word": batch_word,193"padded_word": padded_word,194"real_sentence_lengths": real_sentence_lengths,195"padded_raw_word": [sent + [""] * (max_len_sentences - len(sent)) for sent in batch_raw_word],196}197
198if "indexed_tag" in data:199padded_tag = [200np.lib.pad(201sent,202(0, max_len_sentences - len(sent)),203"constant",204constant_values=(self.tag2id["<PAD>"], self.tag2id["<PAD>"]),205)206for sent in batch_tag207]208batch["padded_tag"] = padded_tag209batch["batch_tag"] = batch_tag210
211# pad chars212max_len_of_sentence = max([len(sentence) for sentence in batch_char])213max_len_of_word = max([max([len(word) for word in sentence]) for sentence in batch_char])214
215padding_word = np.full(max_len_of_word, self.char2id["<PAD>"])216padded_char = []217
218lengths_of_word = []219
220for sentence in batch_char:221padded_sentence = []222length_of_word_in_sentence = []223
224for word in sentence:225length_of_word_in_sentence.append(len(word))226padded_sentence.append(227np.lib.pad(228word,229(0, max_len_of_word - len(word)),230"constant",231constant_values=(self.char2id["<PAD>"], self.char2id["<PAD>"]),232)233)234
235for i in range(max_len_of_sentence - len(padded_sentence)):236padded_sentence.append(padding_word)237length_of_word_in_sentence.append(0)238
239padded_char.append(padded_sentence)240lengths_of_word.append(length_of_word_in_sentence)241
242lengths_of_word = np.array(lengths_of_word)243
244batch["padded_char"] = padded_char245batch["lengths_of_word"] = lengths_of_word246
247return batch, end_idx248
249def train(self, training_file_path, val_file_path, output_model_path=None, nb_epochs=20, init_model_path=None):250raw_train_data = self.read_raw_data(raw_file_path=training_file_path, min_length_of_sentence=2)251raw_val_data = self.read_raw_data(raw_file_path=val_file_path, min_length_of_sentence=2)252
253indexed_train_data = self.index_data(raw_train_data)254indexed_val_data = self.index_data(raw_val_data)255
256saver = tf.train.Saver()257best_f1 = 0258with tf.Session() as sess:259sess.run(tf.global_variables_initializer())260# reload model261if self.pretrained_model_path != "":262print("Model is being reloaded from {:}".format(self.pretrained_model_path))263saver.restore(sess, self.pretrained_model_path + "/model")264_, best_f1 = self.evaluate(sess, indexed_val_data)265print("The best f1: {:06.05}".format(best_f1))266
267# train model268print(self.params)269for epoch in range(nb_epochs):270# shuffle data271random_indexes = np.random.permutation(len(indexed_train_data["indexed_word"]))272data = {}273for i in indexed_train_data:274data[i] = [indexed_train_data[i][j] for j in random_indexes]275
276losses_of_batches = []277current_idx = 0278while current_idx < len(data["indexed_word"]):279batch, current_idx = self.get_batch(data, current_idx)280feed_dict = {281self.tf_word_ids: batch["padded_word"],282self.tf_sentence_lengths: batch["real_sentence_lengths"],283self.tf_labels: batch["padded_tag"],284self.tf_learning_rate: self.params.learning_rate,285self.tf_dropout: self.params.dropout,286self.tf_char_ids: batch["padded_char"],287self.tf_word_lengths: batch["lengths_of_word"],288self.tf_raw_word: batch["padded_raw_word"],289}290
291_, train_loss = sess.run([self.tf_train_op, self.tf_loss], feed_dict=feed_dict)292losses_of_batches.append(train_loss)293
294mean_loss = np.mean(losses_of_batches)295
296# evaluate model on the dev set297acc, f1 = self.evaluate(sess, indexed_val_data)298if f1 > best_f1:299best_f1 = f1300if output_model_path is not None:301saver.save(sess, output_model_path)302print(303"Epoch {:2d}: Train: mean loss: {:.4f} | Val. set: acc: {:.4f}, f1: {:.4f} (*).".format(304epoch, mean_loss, acc, f1305)306)307else:308print(309"Epoch {:2d}: Train: mean loss: {:.4f} | Val. set: acc: {:.4f}, f1: {:.4f} (*).".format(310epoch, mean_loss, acc, f1311)312)313else:314print(315"Epoch {:2d}: Train: mean loss: {:.4f} | Val. set: acc: {:.4f}, f1: {:.4f}".format(316epoch, mean_loss, acc, f1317)318)319print("Training finished.")320
321def evaluate(self, sess, data):322accs = []323correct_preds, total_correct, total_preds = 0.0, 0.0, 0.0324current_idx = 0325while current_idx < len(data["indexed_word"]):326batch, current_idx = self.get_batch(data, current_idx)327# decode using Viterbi algorithm328viterbi_sequences = []329feed_dict = {330self.tf_word_ids: batch["padded_word"],331self.tf_sentence_lengths: batch["real_sentence_lengths"],332self.tf_dropout: 1.0,333self.tf_char_ids: batch["padded_char"],334self.tf_word_lengths: batch["lengths_of_word"],335self.tf_raw_word: batch["padded_raw_word"],336}337_logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)338
339# iterate over the sentences340for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):341# keep only the valid time steps342_logit = _logit[:sequence_length]343viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)344viterbi_sequences += [viterbi_sequence]345
346for lab, lab_pred in zip(batch["batch_tag"], viterbi_sequences):347accs += [a == b for (a, b) in zip(lab, lab_pred)]348lab_chunks = set(helper.get_chunks(lab, self.tag2id))349lab_pred_chunks = set(helper.get_chunks(lab_pred, self.tag2id))350correct_preds += len(lab_chunks & lab_pred_chunks)351total_preds += len(lab_pred_chunks)352total_correct += len(lab_chunks)353
354p = correct_preds / total_preds if correct_preds > 0 else 0355r = correct_preds / total_correct if correct_preds > 0 else 0356f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0357acc = np.mean(accs)358return acc, f1359
360def evaluate_using_conlleval(361self, model_path, testing_file_path, output_folder, min_length_of_sentence=0, show_score_file=True362):363output_file = os.path.join(output_folder, "result.txt")364score_file = os.path.join(output_folder, "score.txt")365
366raw_data = self.read_raw_data(raw_file_path=testing_file_path, min_length_of_sentence=1)367indexed_data = self.index_data(raw_data)368
369f_out = open(output_file, "w", encoding="utf8")370saver = tf.train.Saver()371with tf.Session() as sess:372saver.restore(sess, model_path)373current_idx = 0374while current_idx < len(indexed_data["indexed_word"]):375batch, current_idx = self.get_batch(indexed_data, current_idx)376
377# decode using Viterbi algorithm378viterbi_sequences = []379feed_dict = {380self.tf_word_ids: batch["padded_word"],381self.tf_sentence_lengths: batch["real_sentence_lengths"],382self.tf_dropout: 1.0,383self.tf_char_ids: batch["padded_char"],384self.tf_word_lengths: batch["lengths_of_word"],385self.tf_raw_word: batch["padded_raw_word"],386}387_logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)388
389# iterate over the sentences390for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):391# keep only the valid time steps392_logit = _logit[:sequence_length]393viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)394viterbi_sequences += [viterbi_sequence]395
396for words, labs, lab_preds in zip(batch["batch_word"], batch["batch_tag"], viterbi_sequences):397for word, lab, lab_pred in zip(words, labs, lab_preds):398f_out.write("{:} {:} {:}\n".format(self.id2word[word], self.id2tag[lab], self.id2tag[lab_pred]))399f_out.write("\n")400
401f_out.close()402os.system('perl "%s" < "%s" > "%s"' % ("conlleval", output_file, score_file))403print("Tagging output and testing results were written to " + output_file + " and " + score_file)404
405if show_score_file:406print("Score on {} calculated by ConllEval:".format(testing_file_path))407f_in = open(score_file, mode="r", encoding="utf8")408for line in f_in.readlines():409print(line)410
411def predict(self, sess, text):412inp_sent = text413
414if inp_sent == "":415return ""416
417for p in [".", "?", "!"]:418if p in inp_sent:419return inp_sent420
421words = word_tokenize(inp_sent)422
423raw_data = {"word": [words]}424
425indexed_data = self.index_data(raw_data)426
427current_idx = 0428while current_idx < len(indexed_data["indexed_word"]):429batch, current_idx = self.get_batch(indexed_data, current_idx)430
431# decode using Viterbi algorithm432viterbi_sequences = []433feed_dict = {434self.tf_word_ids: batch["padded_word"],435self.tf_sentence_lengths: batch["real_sentence_lengths"],436self.tf_dropout: 1.0,437self.tf_char_ids: batch["padded_char"],438self.tf_word_lengths: batch["lengths_of_word"],439self.tf_raw_word: batch["padded_raw_word"],440}441_logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)442
443# iterate over the sentences444for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):445# keep only the valid time steps446_logit = _logit[:sequence_length]447viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)448viterbi_sequences += [viterbi_sequence]449
450# pred_labels = [[self.id2tag[t] for t in s] for s in viterbi_sequences]451pred_labels = [self.id2tag[t] for t in viterbi_sequences[0]]452# print("Pred_labels: ",pred_labels)453
454tag2text = {"B-S": ".", "B-Q": "?", "O": "."}455
456punctuation = tag2text[pred_labels[0]]457sent = words[0]458
459for word, tag in zip(words[1:], pred_labels[1:]):460if tag != "O":461sent += punctuation462punctuation = tag2text[tag]463sent += " " + word464sent += punctuation465
466return sent467