sentsegmodel.py
import os

import helper
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from nltk import word_tokenize
from tensorflow.contrib.layers import xavier_initializer, xavier_initializer_conv2d


class model:
    def __init__(self, params, pretrained_model_path=""):
        self.params = params
        self.pretrained_model_path = pretrained_model_path
        dicts = helper.load_dictionaries(self.params.dicts_file)
        self.word2id = dicts["word2id"]
        self.id2word = dicts["id2word"]
        self.char2id = dicts["char2id"]
        self.id2char = dicts["id2char"]
        self.tag2id = dicts["tag2id"]
        self.id2tag = dicts["id2tag"]

        self.pretrained_emb = np.zeros(shape=(len(self.word2id), self.params.word_dim))
        if self.pretrained_model_path == "" and self.params.train != "" and self.params.pretrained_emb != "":
            self.pretrained_emb = helper.load_word_emb(self.word2id, self.pretrained_emb, self.params.pretrained_emb)

        # build model
        self.tf_word_ids = tf.placeholder(dtype=tf.int32, shape=[None, None], name="word_ids")
        self.tf_sentence_lengths = tf.placeholder(dtype=tf.int32, shape=[None], name="sentence_lengths")
        self.tf_labels = tf.placeholder(dtype=tf.int32, shape=[None, None], name="labels")
        self.tf_dropout = tf.placeholder(dtype=tf.float32, shape=[], name="drop_out")
        self.tf_learning_rate = tf.placeholder(dtype=tf.float32, shape=[], name="learning_rate")
        self.tf_char_ids = tf.placeholder(dtype=tf.int32, shape=[None, None, None], name="char_ids")
        self.tf_word_lengths = tf.placeholder(dtype=tf.int32, shape=[None, None], name="word_lengths")
        self.tf_raw_word = tf.placeholder(dtype=tf.string, shape=[None, None], name="raw_word")

        with tf.variable_scope("word_embedding"):
            tf_word_embeddings = tf.Variable(
                self.pretrained_emb, dtype=tf.float32, trainable=True, name="word_embedding"
            )
            self.input = tf.nn.embedding_lookup(tf_word_embeddings, self.tf_word_ids, name="embedded_words")

        with tf.variable_scope("char_cnn"):
            tf_char_embeddings = tf.get_variable(
                name="char_embeddings",
                dtype=tf.float32,
                shape=[len(self.char2id), self.params.char_dim],
                trainable=True,
                initializer=xavier_initializer(),
            )
            embedded_cnn_chars = tf.nn.embedding_lookup(tf_char_embeddings, self.tf_char_ids, name="embedded_cnn_chars")
            conv1 = tf.layers.conv2d(
                inputs=embedded_cnn_chars,
                filters=self.params.nb_filters_1,
                kernel_size=(1, 3),
                strides=(1, 1),
                padding="same",
                name="conv1",
                kernel_initializer=xavier_initializer_conv2d(),
            )
            conv2 = tf.layers.conv2d(
                inputs=conv1,
                filters=self.params.nb_filters_2,
                kernel_size=(1, 3),
                strides=(1, 1),
                padding="same",
                name="conv2",
                kernel_initializer=xavier_initializer_conv2d(),
            )
            char_cnn = tf.reduce_max(conv2, axis=2)
            self.input = tf.concat([self.input, char_cnn], axis=-1)
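            # Shape note: embedded_cnn_chars is (batch, max_sent_len, max_word_len, char_dim);
            # the two (1, 3) convolutions slide along the character axis only, and reduce_max
            # over axis=2 collapses each word's characters into a single nb_filters_2-dimensional
            # vector, which is concatenated to the word embedding above.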

        with tf.variable_scope("elmo_emb"):
            elmo = hub.Module("/elmo2", trainable=False)
            embeddings = elmo(
                inputs={"tokens": self.tf_raw_word, "sequence_len": self.tf_sentence_lengths},
                signature="tokens",
                as_dict=True,
            )["elmo"]  # num_sent, max_sent_len, 1024
            elmo_emb = tf.layers.dense(inputs=embeddings, units=self.params.elmo_dim, activation=None)
            self.input = tf.concat([self.input, elmo_emb], axis=-1)

        self.input = tf.nn.dropout(self.input, self.tf_dropout)

        with tf.variable_scope("bi_lstm_words"):
            cell_fw = tf.contrib.rnn.LSTMCell(self.params.word_hidden_size)
            cell_bw = tf.contrib.rnn.LSTMCell(self.params.word_hidden_size)
            (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, self.input, sequence_length=self.tf_sentence_lengths, dtype=tf.float32
            )
            self.output = tf.concat([output_fw, output_bw], axis=-1)
            ntime_steps = tf.shape(self.output)[1]
            self.output = tf.reshape(self.output, [-1, 2 * params.word_hidden_size])
            layer1 = tf.nn.dropout(
                tf.layers.dense(
                    inputs=self.output,
                    units=params.word_hidden_size,
                    activation=None,
                    kernel_initializer=xavier_initializer(),
                ),
                self.tf_dropout,
            )
            pred = tf.layers.dense(
                inputs=layer1, units=len(self.tag2id), activation=None, kernel_initializer=xavier_initializer()
            )
            self.logits = tf.reshape(pred, [-1, ntime_steps, len(self.tag2id)])

            # compute loss value using crf
            log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
                self.logits, self.tf_labels, self.tf_sentence_lengths
            )
        with tf.variable_scope("loss_and_opt"):
            self.tf_loss = tf.reduce_mean(-log_likelihood)
            optimizer = tf.train.AdamOptimizer(learning_rate=self.tf_learning_rate)
            self.tf_train_op = optimizer.minimize(self.tf_loss)
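        # Architecture summary: each token is represented by the concatenation of a word
        # embedding, a char-CNN feature vector, and a projected ELMo vector; after dropout the
        # sequence runs through a bidirectional LSTM and two dense layers, and a linear-chain
        # CRF supplies the training loss (crf_log_likelihood) and Viterbi decoding at inference.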

    def read_raw_data(self, raw_file_path, min_length_of_sentence):
        # return raw_data{word, tag}
        word, word_, tag, tag_ = [], [], [], []
        nb_part = 2
        with open(file=raw_file_path, mode="r", encoding="utf8") as f:
            lines = f.readlines()
        for line in lines:
            if line.startswith("-DOCSTART-"):
                continue
            tokens = line.strip().split()
            # end of the sentence: keep it only if it is long enough
            if len(tokens) == 0:
                if len(word_) >= min_length_of_sentence:
                    word.append(word_)
                    tag.append(tag_)
                word_, tag_ = [], []
                continue
            if len(tokens) < nb_part:
                print("* input data is not valid:", line)
                continue
            word_.append(tokens[0])
            tag_.append(tokens[-1])
        raw_data = {"word": word, "tag": tag}
        return raw_data
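    # Illustrative input example (hypothetical data, consistent with the parser above and the
    # tags used in predict() below): CoNLL-style text with one token per line, the tag in the
    # last column, sequences separated by blank lines, and "-DOCSTART-" lines ignored, e.g.
    #
    #   how   B-Q
    #   are   O
    #   you   O
    #   i     B-S
    #   am    O
    #   fine  O
    #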

    def index_data(self, raw_data):
        # input: raw_data{word, tag}
        # output: indexed_data{indexed_word, indexed_char, indexed_tag}
        def low(x):
            return x.lower() if self.params.lower == 1 else x

        def zer(s):
            return helper.zeros(s) if self.params.zeros == 1 else s

        word = [[low(zer(x)) for x in s] for s in raw_data["word"]]
        indexed_word = [[self.word2id[w] if w in self.word2id else self.word2id["<UNK>"] for w in s] for s in word]
        indexed_data = {"indexed_word": indexed_word, "raw_word": raw_data["word"]}
        if "tag" in raw_data:
            indexed_tag = [[self.tag2id[t] for t in s] for s in raw_data["tag"]]
            indexed_data["indexed_tag"] = indexed_tag
        indexed_char = [
            [[self.char2id[c] if c in self.char2id else self.char2id["<UNK>"] for c in zer(w)] for w in s]
            for s in raw_data["word"]
        ]
        indexed_data["indexed_char"] = indexed_char
        return indexed_data

    def get_batch(self, data, start_idx):
        # input: data{indexed_word, indexed_char, indexed_tag, raw_word}
        # output: a batch of data after padding
        nb_sentences = len(data["indexed_word"])
        end_idx = start_idx + self.params.batch_size
        if end_idx > nb_sentences:
            end_idx = nb_sentences
        batch_word = data["indexed_word"][start_idx:end_idx]
        if "indexed_tag" in data:
            batch_tag = data["indexed_tag"][start_idx:end_idx]
        batch_char = data["indexed_char"][start_idx:end_idx]
        batch_raw_word = data["raw_word"][start_idx:end_idx]
        real_sentence_lengths = [len(sent) for sent in batch_word]
        max_len_sentences = max(real_sentence_lengths)

        padded_word = [
            np.pad(
                sent,
                (0, max_len_sentences - len(sent)),
                "constant",
                constant_values=(self.word2id["<PAD>"], self.word2id["<PAD>"]),
            )
            for sent in batch_word
        ]

        batch = {
            "batch_word": batch_word,
            "padded_word": padded_word,
            "real_sentence_lengths": real_sentence_lengths,
            "padded_raw_word": [sent + [""] * (max_len_sentences - len(sent)) for sent in batch_raw_word],
        }

        if "indexed_tag" in data:
            padded_tag = [
                np.pad(
                    sent,
                    (0, max_len_sentences - len(sent)),
                    "constant",
                    constant_values=(self.tag2id["<PAD>"], self.tag2id["<PAD>"]),
                )
                for sent in batch_tag
            ]
            batch["padded_tag"] = padded_tag
            batch["batch_tag"] = batch_tag

        # pad chars
        max_len_of_sentence = max([len(sentence) for sentence in batch_char])
        max_len_of_word = max([max([len(word) for word in sentence]) for sentence in batch_char])

        padding_word = np.full(max_len_of_word, self.char2id["<PAD>"])
        padded_char = []

        lengths_of_word = []

        for sentence in batch_char:
            padded_sentence = []
            length_of_word_in_sentence = []

            for word in sentence:
                length_of_word_in_sentence.append(len(word))
                padded_sentence.append(
                    np.pad(
                        word,
                        (0, max_len_of_word - len(word)),
                        "constant",
                        constant_values=(self.char2id["<PAD>"], self.char2id["<PAD>"]),
                    )
                )

            for i in range(max_len_of_sentence - len(padded_sentence)):
                padded_sentence.append(padding_word)
                length_of_word_in_sentence.append(0)

            padded_char.append(padded_sentence)
            lengths_of_word.append(length_of_word_in_sentence)

        lengths_of_word = np.array(lengths_of_word)

        batch["padded_char"] = padded_char
        batch["lengths_of_word"] = lengths_of_word

        return batch, end_idx
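    # Illustrative example (hypothetical ids, assuming <PAD> maps to 0): for a batch holding the
    # two indexed sentences [5, 8, 2] and [7, 4], every sentence is padded to the longest one in
    # the batch, so
    #   padded_word            -> [[5, 8, 2], [7, 4, 0]]
    #   real_sentence_lengths  -> [3, 2]
    #   padded_raw_word        -> the raw tokens with "" appended up to the same length
    # and padded_char / lengths_of_word are padded analogously at the character level.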

    def train(self, training_file_path, val_file_path, output_model_path=None, nb_epochs=20, init_model_path=None):
        raw_train_data = self.read_raw_data(raw_file_path=training_file_path, min_length_of_sentence=2)
        raw_val_data = self.read_raw_data(raw_file_path=val_file_path, min_length_of_sentence=2)

        indexed_train_data = self.index_data(raw_train_data)
        indexed_val_data = self.index_data(raw_val_data)

        saver = tf.train.Saver()
        best_f1 = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # reload model
            if self.pretrained_model_path != "":
                print("Model is being reloaded from {:}".format(self.pretrained_model_path))
                saver.restore(sess, self.pretrained_model_path + "/model")
                _, best_f1 = self.evaluate(sess, indexed_val_data)
                print("The best f1: {:06.05}".format(best_f1))

            # train model
            print(self.params)
            for epoch in range(nb_epochs):
                # shuffle data
                random_indexes = np.random.permutation(len(indexed_train_data["indexed_word"]))
                data = {}
                for i in indexed_train_data:
                    data[i] = [indexed_train_data[i][j] for j in random_indexes]

                losses_of_batches = []
                current_idx = 0
                while current_idx < len(data["indexed_word"]):
                    batch, current_idx = self.get_batch(data, current_idx)
                    feed_dict = {
                        self.tf_word_ids: batch["padded_word"],
                        self.tf_sentence_lengths: batch["real_sentence_lengths"],
                        self.tf_labels: batch["padded_tag"],
                        self.tf_learning_rate: self.params.learning_rate,
                        self.tf_dropout: self.params.dropout,
                        self.tf_char_ids: batch["padded_char"],
                        self.tf_word_lengths: batch["lengths_of_word"],
                        self.tf_raw_word: batch["padded_raw_word"],
                    }

                    _, train_loss = sess.run([self.tf_train_op, self.tf_loss], feed_dict=feed_dict)
                    losses_of_batches.append(train_loss)

                mean_loss = np.mean(losses_of_batches)

                # evaluate model on the dev set; save when the f1 score improves
                acc, f1 = self.evaluate(sess, indexed_val_data)
                if f1 > best_f1:
                    best_f1 = f1
                    if output_model_path is not None:
                        saver.save(sess, output_model_path)
                    print(
                        "Epoch {:2d}: Train: mean loss: {:.4f} | Val. set: acc: {:.4f}, f1: {:.4f} (*).".format(
                            epoch, mean_loss, acc, f1
                        )
                    )
                else:
                    print(
                        "Epoch {:2d}: Train: mean loss: {:.4f} | Val. set: acc: {:.4f}, f1: {:.4f}".format(
                            epoch, mean_loss, acc, f1
                        )
                    )
        print("Training finished.")

    def evaluate(self, sess, data):
        accs = []
        correct_preds, total_correct, total_preds = 0.0, 0.0, 0.0
        current_idx = 0
        while current_idx < len(data["indexed_word"]):
            batch, current_idx = self.get_batch(data, current_idx)
            # decode using Viterbi algorithm
            viterbi_sequences = []
            feed_dict = {
                self.tf_word_ids: batch["padded_word"],
                self.tf_sentence_lengths: batch["real_sentence_lengths"],
                self.tf_dropout: 1.0,
                self.tf_char_ids: batch["padded_char"],
                self.tf_word_lengths: batch["lengths_of_word"],
                self.tf_raw_word: batch["padded_raw_word"],
            }
            _logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)

            # iterate over the sentences
            for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):
                # keep only the valid time steps
                _logit = _logit[:sequence_length]
                viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)
                viterbi_sequences += [viterbi_sequence]

            for lab, lab_pred in zip(batch["batch_tag"], viterbi_sequences):
                accs += [a == b for (a, b) in zip(lab, lab_pred)]
                lab_chunks = set(helper.get_chunks(lab, self.tag2id))
                lab_pred_chunks = set(helper.get_chunks(lab_pred, self.tag2id))
                correct_preds += len(lab_chunks & lab_pred_chunks)
                total_preds += len(lab_pred_chunks)
                total_correct += len(lab_chunks)

        p = correct_preds / total_preds if correct_preds > 0 else 0
        r = correct_preds / total_correct if correct_preds > 0 else 0
        f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
        acc = np.mean(accs)
        return acc, f1

    def evaluate_using_conlleval(
        self, model_path, testing_file_path, output_folder, min_length_of_sentence=0, show_score_file=True
    ):
        output_file = os.path.join(output_folder, "result.txt")
        score_file = os.path.join(output_folder, "score.txt")

        raw_data = self.read_raw_data(raw_file_path=testing_file_path, min_length_of_sentence=1)
        indexed_data = self.index_data(raw_data)

        f_out = open(output_file, "w", encoding="utf8")
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, model_path)
            current_idx = 0
            while current_idx < len(indexed_data["indexed_word"]):
                batch, current_idx = self.get_batch(indexed_data, current_idx)

                # decode using Viterbi algorithm
                viterbi_sequences = []
                feed_dict = {
                    self.tf_word_ids: batch["padded_word"],
                    self.tf_sentence_lengths: batch["real_sentence_lengths"],
                    self.tf_dropout: 1.0,
                    self.tf_char_ids: batch["padded_char"],
                    self.tf_word_lengths: batch["lengths_of_word"],
                    self.tf_raw_word: batch["padded_raw_word"],
                }
                _logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)

                # iterate over the sentences
                for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):
                    # keep only the valid time steps
                    _logit = _logit[:sequence_length]
                    viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)
                    viterbi_sequences += [viterbi_sequence]

                for words, labs, lab_preds in zip(batch["batch_word"], batch["batch_tag"], viterbi_sequences):
                    for word, lab, lab_pred in zip(words, labs, lab_preds):
                        f_out.write("{:} {:} {:}\n".format(self.id2word[word], self.id2tag[lab], self.id2tag[lab_pred]))
                    f_out.write("\n")

        f_out.close()
        os.system('perl "%s" < "%s" > "%s"' % ("conlleval", output_file, score_file))
        print("Tagging output and testing results were written to " + output_file + " and " + score_file)

        if show_score_file:
            print("Score on {} calculated by ConllEval:".format(testing_file_path))
            with open(score_file, mode="r", encoding="utf8") as f_in:
                for line in f_in.readlines():
                    print(line)

    def predict(self, sess, text):
        inp_sent = text

        if inp_sent == "":
            return ""

        # if the text already contains sentence-final punctuation, return it unchanged
        for p in [".", "?", "!"]:
            if p in inp_sent:
                return inp_sent

        words = word_tokenize(inp_sent)

        raw_data = {"word": [words]}

        indexed_data = self.index_data(raw_data)

        current_idx = 0
        while current_idx < len(indexed_data["indexed_word"]):
            batch, current_idx = self.get_batch(indexed_data, current_idx)

            # decode using Viterbi algorithm
            viterbi_sequences = []
            feed_dict = {
                self.tf_word_ids: batch["padded_word"],
                self.tf_sentence_lengths: batch["real_sentence_lengths"],
                self.tf_dropout: 1.0,
                self.tf_char_ids: batch["padded_char"],
                self.tf_word_lengths: batch["lengths_of_word"],
                self.tf_raw_word: batch["padded_raw_word"],
            }
            _logits, _transition_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)

            # iterate over the sentences
            for _logit, sequence_length in zip(_logits, batch["real_sentence_lengths"]):
                # keep only the valid time steps
                _logit = _logit[:sequence_length]
                viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(_logit, _transition_params)
                viterbi_sequences += [viterbi_sequence]

            pred_labels = [self.id2tag[t] for t in viterbi_sequences[0]]

            # map the tag of each predicted sentence start to the punctuation that closes it
            tag2text = {"B-S": ".", "B-Q": "?", "O": "."}

            punctuation = tag2text[pred_labels[0]]
            sent = words[0]

            for word, tag in zip(words[1:], pred_labels[1:]):
                if tag != "O":
                    sent += punctuation
                    punctuation = tag2text[tag]
                sent += " " + word
            sent += punctuation

            return sent
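
# Usage sketch (illustrative only; the config file, params loader, and checkpoint prefix below
# are hypothetical placeholders, and the constructor additionally expects the dictionaries at
# params.dicts_file and the ELMo TF-Hub module at "/elmo2" hard-coded above):
#
#     import tensorflow as tf
#     import helper
#     from sentsegmodel import model
#
#     params = helper.load_parameters("config.ini")   # hypothetical loader returning a params object
#     m = model(params)
#     saver = tf.train.Saver()
#     with tf.Session() as sess:
#         saver.restore(sess, "trained_model/model")  # hypothetical checkpoint prefix
#         print(m.predict(sess, "how are you i am fine"))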