# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
"""Tokenization classes for OpenAI GPT."""

from __future__ import absolute_import, division, print_function, unicode_literals

import itertools
import json
import logging
import os
import sys
import warnings
from io import open

import regex as re
from ppfleetx.utils.download import cached_path

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func


from ppfleetx.utils.log import logger

try:
    import paddlenlp
    from paddlenlp.transformers.gpt.tokenizer import GPTChineseTokenizer
except ImportError:
    raise ImportError("Please import paddlenlp before running the GPT tasks.")

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    "gpt2": "http://fleet.bj.bcebos.com/datasets/gpt/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    "gpt2": "http://fleet.bj.bcebos.com/datasets/gpt/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    "gpt2": 1024,
}
VOCAB_NAME = "vocab.json"
MERGES_NAME = "merges.txt"
SPECIAL_TOKENS_NAME = "special_tokens.txt"


@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))
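
# Illustrative sketch (not part of the original file): every byte value gets a
# printable, non-whitespace unicode stand-in -- printable bytes map to themselves,
# the remaining 68 bytes are shifted into the range starting at chr(256).
# Assuming a standard Python 3 interpreter:
#
#   byte_encoder = bytes_to_unicode()
#   byte_encoder[ord("A")]   # 'A'   (printable bytes map to themselves)
#   byte_encoder[ord(" ")]   # 'Ġ'   (space, byte 0x20, becomes chr(256 + 32))
#   len(byte_encoder)        # 256   (the mapping covers every byte value)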


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
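
# Illustrative sketch (not part of the original file): adjacent-pair extraction
# over a tuple of symbols, no tokenizer state required.
#
#   get_pairs(("h", "e", "l", "l", "o"))
#   # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}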


class GPTTokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """

    padding_side = "right"
    truncation_side = "right"
    model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
    pad_token_type_id = 0
    pad_token_id = 0

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a GPTTokenizer from a pre-trained vocabulary file.
        Download and cache the vocabulary files if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except Exception as e:
            logger.info(e)
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ", ".join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file,
                    merges_file,
                )
            )
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs["max_len"] = min(kwargs.get("max_len", int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and "special_tokens" not in kwargs:
            special_tokens = open(special_tokens_file, encoding="utf-8").read().split("\n")[:-1]
        else:
            special_tokens = kwargs.pop("special_tokens", [])
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer
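
    # Illustrative usage sketch (not part of the original file), assuming either
    # network access to the "gpt2" vocab/merges URLs above or a local directory
    # containing vocab.json and merges.txt:
    #
    #   tokenizer = GPTTokenizer.from_pretrained("gpt2")
    #   ids = tokenizer.encode("Hello world")
    #   tokenizer.decode(ids)  # 'Hello world'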

    def __init__(self, vocab_file, merges_file, errors="replace", special_tokens=None, max_len=None, **kwargs):

        self.padding_side = kwargs.pop("padding_side", self.padding_side)
        if self.padding_side not in ["right", "left"]:
            raise ValueError(
                f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
            )

        self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
        if self.truncation_side not in ["right", "left"]:
            raise ValueError(
                f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
            )

        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding="utf-8").read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.eod_id = self.encoder["<|endoftext|>"]
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __call__(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding=False,
        truncation=False,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
        return_overflowing_tokens=False,
        return_length=False,
    ):
        assert padding in [True, False, "longest", "max_length", "do_not_pad"]

        if max_length is not None and padding is False and truncation is False:
            truncation = "longest_first"

        if padding is True:
            padding = "longest"
        elif padding is False:
            padding = "do_not_pad"

        assert truncation in [True, False, "only_first", "only_second", "longest_first", "do_not_truncate"]
        if truncation is True:
            truncation = "longest_first"
        elif truncation is False:
            truncation = "do_not_truncate"

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation != "do_not_truncate"
            and padding != "do_not_pad"
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                "Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        is_batched = isinstance(text, (list, tuple))
        if is_batched:
            raise NotImplementedError
        else:
            return self.encode_plus(
                text=text,
                text_pair=text_pair,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_length=return_length,
            )
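
    # Illustrative call sketch (not part of the original file), assuming a
    # tokenizer built via GPTTokenizer.from_pretrained("gpt2"). padding=True is
    # normalized to "longest" and truncation=True to "longest_first" above:
    #
    #   out = tokenizer("Hello world", max_length=8, padding="max_length", truncation=True)
    #   out["input_ids"]       # token ids, right-padded with pad_token_id (0) up to length 8
    #   out["attention_mask"]  # 1 for real tokens, 0 for padding positions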

    def encode_plus(
        self,
        text,
        text_pair,
        add_special_tokens=True,
        padding="do_not_pad",
        truncation="do_not_truncate",
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
        return_overflowing_tokens=False,
        return_length=False,
        **kwargs
    ):
        # Whether the inputs are already pre-tokenized into words (a list/tuple of strings).
        is_split_into_words = kwargs.pop("is_split_into_words", False)

        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                raise NotImplementedError

        first_ids = get_input_ids(text)
        second_ids = get_input_ids(text_pair) if text_pair is not None else None

        pair = bool(second_ids is not None)
        len_ids = len(first_ids)
        len_pair_ids = len(second_ids) if pair else 0

        if return_token_type_ids and not add_special_tokens:
            raise ValueError(
                "Asking to return token_type_ids while setting add_special_tokens to False "
                "results in an undefined behavior. Please set add_special_tokens to True or "
                "set return_token_type_ids to None."
            )

        # Load from model defaults
        if return_token_type_ids is None:
            return_token_type_ids = "token_type_ids" in self.model_input_names
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        encoded_inputs = {}
        # Compute the total size of the returned encodings
        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)

        # Truncation: Handle max sequence length
        overflowing_tokens = []
        if truncation != "do_not_truncate" and max_length and total_len > max_length:
            first_ids, second_ids, overflowing_tokens = self.truncate_sequences(
                first_ids,
                pair_ids=second_ids,
                num_tokens_to_remove=total_len - max_length,
                truncation=truncation,
            )
        if return_overflowing_tokens:
            encoded_inputs["overflowing_tokens"] = overflowing_tokens
            encoded_inputs["num_truncated_tokens"] = total_len - max_length

        # Add special tokens
        if add_special_tokens:
            sequence = self.build_inputs_with_special_tokens(first_ids, second_ids)
            token_type_ids = self.create_token_type_ids_from_sequences(first_ids, second_ids)
        else:
            sequence = first_ids + second_ids if pair else first_ids
            token_type_ids = [0] * len(first_ids) + ([0] * len(second_ids) if pair else [])

        # Build output dictionary
        encoded_inputs["input_ids"] = sequence
        if return_token_type_ids:
            encoded_inputs["token_type_ids"] = token_type_ids

        # Padding
        if padding != "do_not_pad" or return_attention_mask:
            encoded_inputs = self.pad(
                encoded_inputs,
                max_length=max_length,
                padding=padding,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

        if return_length:
            encoded_inputs["length"] = len(encoded_inputs["input_ids"])

        return encoded_inputs

    def num_special_tokens_to_add(self, pair: bool = False) -> int:
        token_ids_0 = []
        token_ids_1 = []
        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        if token_ids_1 is None:
            return len(token_ids_0) * [0]
        return [0] * len(token_ids_0) + [1] * len(token_ids_1)

    def truncate_sequences(
        self,
        ids,
        pair_ids=None,
        num_tokens_to_remove=0,
        truncation="longest_first",
        stride=0,
    ):
        if num_tokens_to_remove <= 0:
            return ids, pair_ids, []

        overflowing_tokens = []
        if truncation == "only_first" or (truncation == "longest_first" and pair_ids is None):
            if len(ids) > num_tokens_to_remove:
                window_len = min(len(ids), stride + num_tokens_to_remove)
                if self.truncation_side == "left":
                    overflowing_tokens = ids[:window_len]
                    ids = ids[num_tokens_to_remove:]
                elif self.truncation_side == "right":
                    overflowing_tokens = ids[-window_len:]
                    ids = ids[:-num_tokens_to_remove]
                else:
                    raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")

            else:
                error_msg = (
                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
                    f"but the first sequence has a length of only {len(ids)}. "
                )
                if truncation == "only_first":
                    error_msg = (
                        error_msg + "Please select another truncation strategy than "
                        f"{truncation}, for instance 'longest_first' or 'only_second'."
                    )
                logger.error(error_msg)
        elif truncation == "longest_first":
            warnings.warn(
                "Be aware, overflowing tokens are not returned for the setting you have chosen,"
                f" i.e. sequence pairs with the '{truncation}' "
                "truncation strategy. So the returned list will always be empty even if some "
                "tokens have been removed."
            )
            for _ in range(num_tokens_to_remove):
                if pair_ids is None or len(ids) > len(pair_ids):
                    if self.truncation_side == "right":
                        ids = ids[:-1]
                    elif self.truncation_side == "left":
                        ids = ids[1:]
                    else:
                        raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
                else:
                    if self.truncation_side == "right":
                        pair_ids = pair_ids[:-1]
                    elif self.truncation_side == "left":
                        pair_ids = pair_ids[1:]
                    else:
                        raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
        elif truncation == "only_second" and pair_ids is not None:
            if len(pair_ids) > num_tokens_to_remove:
                window_len = min(len(pair_ids), stride + num_tokens_to_remove)
                if self.truncation_side == "right":
                    overflowing_tokens = pair_ids[-window_len:]
                    pair_ids = pair_ids[:-num_tokens_to_remove]
                elif self.truncation_side == "left":
                    overflowing_tokens = pair_ids[:window_len]
                    pair_ids = pair_ids[num_tokens_to_remove:]
                else:
                    raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
            else:
                logger.error(
                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
                    f"but the second sequence has a length of only {len(pair_ids)}. "
                    f"Please select another truncation strategy than {truncation}, "
                    "for instance 'longest_first' or 'only_first'."
                )

        return (ids, pair_ids, overflowing_tokens)
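
    # Illustrative sketch (not part of the original file): with the default
    # truncation_side == "right" and no pair sequence, "only_first" drops tokens
    # from the end and reports the removed window as overflow:
    #
    #   ids, pair_ids, overflow = tokenizer.truncate_sequences(
    #       [1, 2, 3, 4, 5], num_tokens_to_remove=2, truncation="only_first"
    #   )
    #   # ids == [1, 2, 3], pair_ids is None, overflow == [4, 5]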

    def pad(
        self,
        encoded_inputs,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_attention_mask=None,
        return_tensors=None,
        verbose=True,
    ):

        # The model's main input name, usually `input_ids`, has to be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method "
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                max_length=max_length,
                padding=padding,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return encoded_inputs

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        if padding == "longest":
            max_length = max(len(inputs) for inputs in required_input)
            padding = "max_length"

        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                max_length=max_length,
                padding=padding,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        # Return the collated, padded batch rather than the original inputs.
        return batch_outputs

    def _pad(
        self,
        encoded_inputs,
        max_length=None,
        padding="do_not_pad",
        pad_to_multiple_of=None,
        return_attention_mask=None,
    ) -> dict:
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names or "attention_mask" in encoded_inputs

        required_input = encoded_inputs[self.model_input_names[0]]

        if padding == "longest":
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding != "do_not_pad" and len(required_input) != max_length

        # Initialize attention mask if not present.
        if return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                if "offset_mapping" in encoded_inputs:
                    encoded_inputs["offset_mapping"] = encoded_inputs["offset_mapping"] + [(0, 0)] * difference
                if "position_ids" in encoded_inputs:
                    encoded_inputs["position_ids"] = encoded_inputs["position_ids"] + [0] * difference
                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                        "token_type_ids"
                    ]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                if "offset_mapping" in encoded_inputs:
                    encoded_inputs["offset_mapping"] = [(0, 0)] * difference + encoded_inputs["offset_mapping"]
                if "position_ids" in encoded_inputs:
                    encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return encoded_inputs
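
    # Illustrative sketch (not part of the original file): right-padding a single
    # encoding to max_length=5 with the class defaults pad_token_id == 0 and
    # padding_side == "right":
    #
    #   tokenizer._pad({"input_ids": [11, 22, 33]}, max_length=5, padding="max_length")
    #   # {'input_ids': [11, 22, 33, 0, 0], 'attention_mask': [1, 1, 1, 0, 0]}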

    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)

    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.
        The additional tokens are indexed starting from the last index of the
        current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur in the remaining symbols; keep the rest as-is.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word
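
    # Illustrative sketch (not part of the original file): bpe() greedily applies
    # the lowest-ranked merge from merges.txt until no known pair remains, then
    # joins the resulting subwords with spaces. The exact split depends on the
    # loaded merges file; with the standard gpt2 merges a frequent word typically
    # collapses into a single subword:
    #
    #   tokenizer.bpe("hello")   # 'hello'
    #   tokenizer.bpe("Ġhello")  # 'Ġhello'  ('Ġ' is the byte-level stand-in for a leading space)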

    def tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = "".join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens
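
    # Illustrative sketch (not part of the original file), assuming the standard
    # gpt2 vocab/merges: the regex splits off words and their leading space, the
    # bytes are remapped through byte_encoder, and each piece is BPE-merged:
    #
    #   tokenizer.tokenize("Hello world")
    #   # ['Hello', 'Ġworld']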

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            warnings.warn(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this OpenAI GPT model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_string(self, ids):
        """
        Converts a sequence of token ids to a text string.
        Args:
            ids (List[int]):
                The token ids to be converted to text.
        Returns:
            str: The decoded text.
        Example:
            .. code-block::

                from paddlenlp.transformers import GPTTokenizer
                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                print(tokenizer.convert_ids_to_string([14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930]))
                # 'Welcome to use PaddlePaddle and PaddleNLP'
        """

        text = "".join([self.decoder[id] for id in ids])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids into BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        text = "".join([self.decoder[token] if token in self.decoder.keys() else "" for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
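
    # Illustrative sketch (not part of the original file), assuming the standard
    # gpt2 vocab/merges: encode() is tokenize() followed by a vocab lookup, and
    # decode() inverts it through byte_decoder, so the round trip preserves text:
    #
    #   ids = tokenizer.encode("Hello world")
    #   tokenizer.decode(ids)  # 'Hello world'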

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    warnings.warn(
                        "Saving vocabulary to {}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".format(merge_file)
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    warnings.warn(
                        "Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".format(special_tokens_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1

        return vocab_file, merge_file, special_tokens_file

    @property
    def vocab_size(self):
        return len(self.encoder)

    @property
    def vocab(self):
        return self.encoder

    @property
    def inv_vocab(self):
        return self.decoder

    @property
    def eos_token_id(self):
        return self.eod_id
