# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle
from paddle.optimizer.lr import LambdaDecay

from paddlenlp.transformers import normalize_chars, tokenize_special_chars


def create_dataloader(dataset, mode="train", batch_size=1, batchify_fn=None, trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = mode == "train"
    if mode == "train":
        batch_sampler = paddle.io.DistributedBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=batchify_fn, return_list=True)
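
# Usage sketch (illustrative; `train_ds`, `tokenizer` and `batchify_fn` are
# hypothetical stand-ins for objects built by the training script):
#
#     from functools import partial
#
#     trans_fn = partial(convert_example, tokenizer=tokenizer, max_seq_length=128)
#     train_loader = create_dataloader(
#         train_ds, mode="train", batch_size=32, batchify_fn=batchify_fn, trans_fn=trans_fn
#     )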


class LinearDecayWithWarmup(LambdaDecay):
    def __init__(self, learning_rate, total_steps, warmup, last_epoch=-1, verbose=False):
        """
        Creates a learning rate scheduler which increases the learning rate
        linearly from 0 to the given `learning_rate` during the warmup period,
        then decreases it linearly from the base learning rate to 0.

        Args:
            learning_rate (float):
                The base learning rate. It is a python float number.
            total_steps (int):
                The number of training steps.
            warmup (int or float):
                If int, it means the number of steps for warmup. If float, it means
                the proportion of warmup in total training steps.
            last_epoch (int, optional):
                The index of last epoch. It can be set to restart training.
                Defaults to -1, which means starting from the initial learning rate.
            verbose (bool, optional):
                If True, prints a message to stdout for each update.
                Defaults to False.
        """

        warmup_steps = warmup if isinstance(warmup, int) else int(math.floor(warmup * total_steps))

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return max(0.0, 1.0 - current_step / total_steps)

        super(LinearDecayWithWarmup, self).__init__(learning_rate, lr_lambda, last_epoch, verbose)
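
# How the scheduler is typically wired into an optimizer (hypothetical values;
# `model` stands in for a constructed paddle.nn.Layer):
#
#     lr_scheduler = LinearDecayWithWarmup(learning_rate=5e-5, total_steps=1000, warmup=0.1)
#     optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler, parameters=model.parameters())
#     # warmup=0.1 -> the rate rises linearly over the first 100 steps, then decays
#     # linearly towards 0; call lr_scheduler.step() once per training step.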


def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    """
    Builds model inputs from a sequence or a pair of sequences for sequence
    classification tasks by concatenating and adding special tokens, and
    creates a mask from the two sequences for sequence-pair classification
    tasks.

    The convention in Electra/EHealth is:

    - single sequence:
        input_ids:      ``[CLS] X [SEP]``
        token_type_ids: ``  0   0   0``
        position_ids:   ``  0   1   2``

    - a sequence pair:
        input_ids:      ``[CLS] X [SEP] Y [SEP]``
        token_type_ids: ``  0   0   0   1   1``
        position_ids:   ``  0   1   2   3   4``

    Args:
        example (obj:`dict`):
            A dictionary of input data, containing text and label if it has.
        tokenizer (obj:`PretrainedTokenizer`):
            A tokenizer inherited from :class:`paddlenlp.transformers.PretrainedTokenizer`.
            Users can refer to the superclass for more information.
        max_seq_length (obj:`int`):
            The maximum total input sequence length after tokenization.
            Sequences longer than this are truncated, and shorter ones are padded.
        is_test (obj:`bool`, default to `False`):
            Whether the example contains a label or not.

    Returns:
        input_ids (obj:`list[int]`):
            The list of token ids.
        token_type_ids (obj:`list[int]`):
            List of sequence pair mask.
        position_ids (obj:`list[int]`):
            List of position ids.
        label (obj:`numpy.array`, data type of int64, optional):
            The input label if not is_test.
    """
    text_a = example["text_a"]
    text_b = example.get("text_b", None)

    text_a = tokenize_special_chars(normalize_chars(text_a))
    if text_b is not None:
        text_b = tokenize_special_chars(normalize_chars(text_b))

    encoded_inputs = tokenizer(text=text_a, text_pair=text_b, max_seq_len=max_seq_length, return_position_ids=True)
    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]
    position_ids = encoded_inputs["position_ids"]

    if is_test:
        return input_ids, token_type_ids, position_ids
    label = np.array([example["label"]], dtype="int64")
    return input_ids, token_type_ids, position_ids, label
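
# Illustrative call (hypothetical inputs; any compatible checkpoint name works
# in place of "ernie-health-chinese"):
#
#     from paddlenlp.transformers import ElectraTokenizer
#
#     tokenizer = ElectraTokenizer.from_pretrained("ernie-health-chinese")
#     example = {"text_a": "血压偏高", "text_b": "高血压", "label": 1}
#     input_ids, token_type_ids, position_ids, label = convert_example(
#         example, tokenizer, max_seq_length=32
#     )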


def convert_example_ner(example, tokenizer, max_seq_length=512, pad_label_id=(-100, -100), is_test=False):
    """
    Builds model inputs from a sequence and creates labels for the
    named-entity recognition task CMeEE.

    For example, a sample should be:

    - input_ids:      ``[CLS]  x1   x2 [SEP] [PAD]``
    - token_type_ids: ``  0    0    0    0     0``
    - position_ids:   ``  0    1    2    3     0``
    - attention_mask: ``  1    1    1    1     0``
    - label_oth:      `` 32    3   32   32    32`` (optional, label ids of others)
    - label_sym:      ``  4    4    4    4     4`` (optional, label ids of symptom)

    Args:
        example (obj:`dict`):
            A dictionary of input data, containing text and label if it has.
        tokenizer (obj:`PretrainedTokenizer`):
            A tokenizer inherited from :class:`paddlenlp.transformers.PretrainedTokenizer`.
            Users can refer to the superclass for more information.
        max_seq_length (obj:`int`):
            The maximum total input sequence length after tokenization.
            Sequences longer than this are truncated, and shorter ones are padded.
        pad_label_id (obj:`tuple|list`, default to `(-100, -100)`):
            The label ids assigned to the [CLS] and [SEP] positions, one for
            each of the two label sequences (`label_oth` and `label_sym`).
        is_test (obj:`bool`, default to `False`):
            Whether the example contains a label or not.

    Returns:
        encoded_output (obj: `dict[str, list|np.array]`):
            The sample dictionary including `input_ids`, `token_type_ids`,
            `position_ids`, `attention_mask`, `label_oth` (optional),
            `label_sym` (optional)
    """

    encoded_inputs = {}
    text = example["text"]
    if len(text) > max_seq_length - 2:
        text = text[: max_seq_length - 2]
    text = ["[CLS]"] + [x.lower() for x in text] + ["[SEP]"]
    input_len = len(text)
    encoded_inputs["input_ids"] = tokenizer.convert_tokens_to_ids(text)
    encoded_inputs["token_type_ids"] = np.zeros(input_len)
    encoded_inputs["position_ids"] = list(range(input_len))
    encoded_inputs["attention_mask"] = np.ones(input_len)

    if not is_test:
        labels = example["labels"]
        if input_len - 2 < len(labels[0]):
            labels[0] = labels[0][: input_len - 2]
        if input_len - 2 < len(labels[1]):
            labels[1] = labels[1][: input_len - 2]
        encoded_inputs["label_oth"] = [pad_label_id[0]] + labels[0] + [pad_label_id[0]]
        encoded_inputs["label_sym"] = [pad_label_id[1]] + labels[1] + [pad_label_id[1]]

    return encoded_inputs
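
# Minimal CMeEE-style sketch (hypothetical label ids; real ids come from the
# dataset's label lists). With pad_label_id=(32, 4), the [CLS]/[SEP] positions
# are padded with 32 and 4 in the two label sequences respectively:
#
#     example = {"text": "头痛发热", "labels": [[32, 32, 3, 3], [4, 4, 4, 4]]}
#     features = convert_example_ner(example, tokenizer, max_seq_length=16,
#                                    pad_label_id=(32, 4))
#     # features["label_oth"] == [32, 32, 32, 3, 3, 32]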


def convert_example_spo(example, tokenizer, num_classes, max_seq_length=512, is_test=False):
    """
    Builds model inputs from a sequence and creates labels for the SPO
    prediction task CMeIE.

    For example, a sample should be:

    - input_ids:      ``[CLS]  x1   x2 [SEP] [PAD]``
    - token_type_ids: ``  0    0    0    0     0``
    - position_ids:   ``  0    1    2    3     0``
    - attention_mask: ``  1    1    1    1     0``
    - ent_label:      ``[[0    1    0    0     0],  # start ids are set as 1
                         [0    0    1    0     0]]  # end ids are set as 1
    - spo_label: a tensor of shape [num_classes, max_batch_len, max_batch_len].
                 `spo_label[predicate_id, subject_start_id, object_start_id]` is
                 set to 1 when the triple (subject, predicate, object) exists.

    Args:
        example (obj:`dict`):
            A dictionary of input data, containing text and label if it has.
        tokenizer (obj:`PretrainedTokenizer`):
            A tokenizer inherited from :class:`paddlenlp.transformers.PretrainedTokenizer`.
            Users can refer to the superclass for more information.
        num_classes (obj:`int`):
            The number of predicates.
        max_seq_length (obj:`int`):
            The maximum total input sequence length after tokenization.
            Sequences longer than this are truncated, and shorter ones are padded.
        is_test (obj:`bool`, default to `False`):
            Whether the example contains a label or not.

    Returns:
        encoded_output (obj: `dict[str, list|np.array]`):
            The sample dictionary including `input_ids`, `token_type_ids`,
            `position_ids`, `attention_mask`, `ent_label` (optional),
            `spo_label` (optional)
    """
    encoded_inputs = {}
    text = example["text"]
    if len(text) > max_seq_length - 2:
        text = text[: max_seq_length - 2]
    text = ["[CLS]"] + [x.lower() for x in text] + ["[SEP]"]
    input_len = len(text)
    encoded_inputs["input_ids"] = tokenizer.convert_tokens_to_ids(text)
    encoded_inputs["token_type_ids"] = np.zeros(input_len)
    encoded_inputs["position_ids"] = list(range(input_len))
    encoded_inputs["attention_mask"] = np.ones(input_len)
    if not is_test:
        encoded_inputs["ent_label"] = example["ent_label"]
        encoded_inputs["spo_label"] = example["spo_label"]
    return encoded_inputs
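
# Shape sketch for the SPO labels, as described in the docstring above: for a
# padded length L and num_classes predicates, `ent_label` is a [2, L] start/end
# indicator and `spo_label` a [num_classes, L, L] indicator, so that
#
#     spo_label[p, s, o] == 1  # predicate p links the subject starting at s
#                              # to the object starting at o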


class NERChunkEvaluator(paddle.metric.Metric):
    """
    NERChunkEvaluator computes the precision, recall and F1-score for chunk detection.
    It is often used in sequence tagging tasks, such as Named Entity Recognition (NER).

    Args:
        label_list (list):
            The label list.

    Note:
        Difference from `paddlenlp.metric.ChunkEvaluator`:

        - `paddlenlp.metric.ChunkEvaluator`
           All sequences with non-'O' labels are taken as chunks when computing num_infer.
        - `NERChunkEvaluator`
           Only complete sequences are taken as chunks, namely `B- I- E-` or `S-`.
    """

    def __init__(self, label_list):
        super(NERChunkEvaluator, self).__init__()
        self.id2label = [dict(enumerate(x)) for x in label_list]
        self.num_classes = [len(x) for x in label_list]
        self.num_infer = 0
        self.num_label = 0
        self.num_correct = 0

    def compute(self, lengths, predictions, labels):
        """
        Computes the precision, recall and F1-score for chunk detection.

        Args:
            lengths (Tensor):
                The valid length of every sequence, a tensor with shape `[batch_size]`.
            predictions (list[Tensor]):
                The prediction indexes, one tensor with shape `[batch_size, sequence_length]`
                per label set.
            labels (list[Tensor]):
                The label indexes, one tensor with shape `[batch_size, sequence_length]`
                per label set.

        Returns:
            tuple: Returns tuple (`num_infer_chunks, num_label_chunks, num_correct_chunks`).

            With the fields:

            - `num_infer_chunks` (Tensor): The number of the inference chunks.
            - `num_label_chunks` (Tensor): The number of the label chunks.
            - `num_correct_chunks` (Tensor): The number of the correct chunks.
        """
        assert len(predictions) == len(labels)
        assert len(predictions) == len(self.id2label)
        preds = [x.numpy() for x in predictions]
        labels = [x.numpy() for x in labels]

        preds_chunk = set()
        label_chunk = set()
        for idx, (pred, label) in enumerate(zip(preds, labels)):
            for i, case in enumerate(pred):
                case = [self.id2label[idx][x] for x in case[: lengths[i]]]
                preds_chunk |= self.extract_chunk(case, i)
            for i, case in enumerate(label):
                case = [self.id2label[idx][x] for x in case[: lengths[i]]]
                label_chunk |= self.extract_chunk(case, i)

        num_infer = len(preds_chunk)
        num_label = len(label_chunk)
        num_correct = len(preds_chunk & label_chunk)
        return num_infer, num_label, num_correct

    def update(self, correct):
        num_infer, num_label, num_correct = correct
        self.num_infer += num_infer
        self.num_label += num_label
        self.num_correct += num_correct

    def accumulate(self):
        precision = self.num_correct / (self.num_infer + 1e-6)
        recall = self.num_correct / (self.num_label + 1e-6)
        f1 = 2 * precision * recall / (precision + recall + 1e-6)
        return precision, recall, f1

    def reset(self):
        self.num_infer = 0
        self.num_label = 0
        self.num_correct = 0

    def name(self):
        return "precision", "recall", "f1"

    def extract_chunk(self, sequence, cid=0):
        chunks = set()

        start_idx, cur_idx = 0, 0
        while cur_idx < len(sequence):
            if sequence[cur_idx][0] == "B":
                start_idx = cur_idx
                cur_idx += 1
                while cur_idx < len(sequence) and sequence[cur_idx][0] == "I":
                    if sequence[cur_idx][2:] == sequence[start_idx][2:]:
                        cur_idx += 1
                    else:
                        break
                if cur_idx < len(sequence) and sequence[cur_idx][0] == "E":
                    if sequence[cur_idx][2:] == sequence[start_idx][2:]:
                        chunks.add((cid, sequence[cur_idx][2:], start_idx, cur_idx))
                        cur_idx += 1
            elif sequence[cur_idx][0] == "S":
                chunks.add((cid, sequence[cur_idx][2:], cur_idx, cur_idx))
                cur_idx += 1
            else:
                cur_idx += 1

        return chunks
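
# `extract_chunk` keeps only complete BIOES chunks. A small worked example
# (hypothetical tag set):
#
#     evaluator = NERChunkEvaluator([["O", "B-sym", "I-sym", "E-sym", "S-sym"]])
#     evaluator.extract_chunk(["O", "B-sym", "I-sym", "E-sym", "S-sym"])
#     # -> {(0, 'sym', 1, 3), (0, 'sym', 4, 4)}  i.e. (cid, type, start, end)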


class SPOChunkEvaluator(paddle.metric.Metric):
    """
    SPOChunkEvaluator computes the precision, recall and F1-score for multiple
    chunk detections, including Named Entity Recognition (NER) and SPO Prediction.

    Args:
        num_classes (int):
            The number of predicates.
    """

    def __init__(self, num_classes=None):
        super(SPOChunkEvaluator, self).__init__()
        self.num_classes = num_classes
        self.num_infer_ent = 0
        self.num_infer_spo = 1e-10
        self.num_label_ent = 0
        self.num_label_spo = 1e-10
        self.num_correct_ent = 0
        self.num_correct_spo = 0

    def compute(self, lengths, ent_preds, spo_preds, ent_labels, spo_labels):
        """
        Computes the precision, recall and F1-score for NER and SPO prediction.

        Args:
            lengths (Tensor):
                The valid length of every sequence, a tensor with shape `[batch_size]`.
            ent_preds (Tensor):
                The predictions of entities.
                A tensor with shape `[batch_size, sequence_length, 2]`.
                `ent_preds[:, :, 0]` denotes the start indexes of entities.
                `ent_preds[:, :, 1]` denotes the end indexes of entities.
            spo_preds (Tensor):
                The predictions of predicates between all possible entities.
                A tensor with shape `[batch_size, num_classes, sequence_length, sequence_length]`.
            ent_labels (list[list|tuple]):
                The entity labels' indexes. A list of pairs `[start_index, end_index]`.
            spo_labels (list[list|tuple]):
                The SPO labels' indexes. A list of triples `[[subject_start_index, subject_end_index],
                predicate_id, [object_start_index, object_end_index]]`.

        Returns:
            tuple:
                Returns tuple (`num_infer_chunks, num_label_chunks, num_correct_chunks`).
                Each element is a dict whose `ent` key holds the NER results and
                whose `spo` key holds the SPO prediction results.

            With the fields:

            - `num_infer_chunks` (dict): The number of the inference chunks.
            - `num_label_chunks` (dict): The number of the label chunks.
            - `num_correct_chunks` (dict): The number of the correct chunks.
        """
        ent_preds = ent_preds.numpy()
        spo_preds = spo_preds.numpy()

        ent_pred_list = []
        ent_idxs_list = []
        for idx, ent_pred in enumerate(ent_preds):
            seq_len = lengths[idx] - 2
            start = np.where(ent_pred[:, 0] > 0.5)[0]
            end = np.where(ent_pred[:, 1] > 0.5)[0]
            ent_pred = []
            ent_idxs = {}
            for x in start:
                y = end[end >= x]
                if (x == 0) or (x > seq_len):
                    continue
                if len(y) > 0:
                    y = y[0]
                    if y > seq_len:
                        continue
                    ent_idxs[x] = (x - 1, y - 1)
                    ent_pred.append((x - 1, y - 1))
            ent_pred_list.append(ent_pred)
            ent_idxs_list.append(ent_idxs)

        spo_preds = spo_preds > 0
        spo_pred_list = [[] for _ in range(len(spo_preds))]
        idxs, preds, subs, objs = np.nonzero(spo_preds)
        for idx, p_id, s_id, o_id in zip(idxs, preds, subs, objs):
            obj = ent_idxs_list[idx].get(o_id, None)
            if obj is None:
                continue
            sub = ent_idxs_list[idx].get(s_id, None)
            if sub is None:
                continue
            spo_pred_list[idx].append((sub, p_id, obj))

        correct = {"ent": 0, "spo": 0}
        infer = {"ent": 0, "spo": 0}
        label = {"ent": 0, "spo": 0}
        for ent_pred, ent_true in zip(ent_pred_list, ent_labels):
            ent_true = [tuple(x) for x in ent_true]
            infer["ent"] += len(set(ent_pred))
            label["ent"] += len(set(ent_true))
            correct["ent"] += len(set(ent_pred) & set(ent_true))

        for spo_pred, spo_true in zip(spo_pred_list, spo_labels):
            spo_true = [(tuple(s), p, tuple(o)) for s, p, o in spo_true]
            infer["spo"] += len(set(spo_pred))
            label["spo"] += len(set(spo_true))
            correct["spo"] += len(set(spo_pred) & set(spo_true))

        return infer, label, correct

    def update(self, corrects):
        assert len(corrects) == 3
        for item in corrects:
            assert isinstance(item, dict)
            for value in item.values():
                if not self._is_number_or_matrix(value):
                    raise ValueError("The numbers must be a number(int) or a numpy ndarray.")
        num_infer, num_label, num_correct = corrects
        self.num_infer_ent += num_infer["ent"]
        self.num_infer_spo += num_infer["spo"]
        self.num_label_ent += num_label["ent"]
        self.num_label_spo += num_label["spo"]
        self.num_correct_ent += num_correct["ent"]
        self.num_correct_spo += num_correct["spo"]

    def accumulate(self):
        spo_precision = self.num_correct_spo / self.num_infer_spo
        spo_recall = self.num_correct_spo / self.num_label_spo
        spo_f1 = 2 * self.num_correct_spo / (self.num_infer_spo + self.num_label_spo)
        ent_precision = self.num_correct_ent / self.num_infer_ent if self.num_infer_ent > 0 else 0.0
        ent_recall = self.num_correct_ent / self.num_label_ent if self.num_label_ent > 0 else 0.0
        ent_f1 = (
            2 * ent_precision * ent_recall / (ent_precision + ent_recall) if (ent_precision + ent_recall) != 0 else 0.0
        )
        return {"entity": (ent_precision, ent_recall, ent_f1), "spo": (spo_precision, spo_recall, spo_f1)}

    def _is_number_or_matrix(self, var):
        def _is_number_(var):
            return (
                isinstance(var, int)
                or isinstance(var, np.int64)
                or isinstance(var, float)
                or (isinstance(var, np.ndarray) and var.shape == (1,))
            )

        return _is_number_(var) or isinstance(var, np.ndarray)

    def reset(self):
        self.num_infer_ent = 0
        self.num_infer_spo = 1e-10
        self.num_label_ent = 0
        self.num_label_spo = 1e-10
        self.num_correct_ent = 0
        self.num_correct_spo = 0

    def name(self):
        return {"entity": ("precision", "recall", "f1"), "spo": ("precision", "recall", "f1")}
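
# Typical evaluation loop for these metrics (illustrative; `eval_batches` is a
# hypothetical iterable of model outputs and labels):
#
#     metric = SPOChunkEvaluator(num_classes=num_classes)
#     metric.reset()
#     for lengths, ent_preds, spo_preds, ent_labels, spo_labels in eval_batches:
#         metric.update(metric.compute(lengths, ent_preds, spo_preds, ent_labels, spo_labels))
#     results = metric.accumulate()
#     # results == {"entity": (precision, recall, f1), "spo": (precision, recall, f1)}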
