paddlenlp/tokenization_utils_base.py (1790 lines, 73.0 KB)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
"""
Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (hosts all the
user-facing encoding methods), SpecialTokensMixin (hosts the special tokens logic) and BatchEncoding (wraps the
dictionary of outputs with special methods for the fast tokenizers).
"""

import copy
import importlib
import json
import os
import re
import warnings
from collections import OrderedDict, UserDict
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np


def is_sentencepiece_available():
    return importlib.util.find_spec("sentencepiece") is not None


def is_tokenizers_available():
    return importlib.util.find_spec("tokenizers") is not None


if is_tokenizers_available():
    from tokenizers import AddedToken
else:

    @dataclass(frozen=True, eq=True)
    class AddedToken:
        """
        AddedToken represents a token to be added to a Tokenizer. An AddedToken can have special options defining the
        way it should behave.
        """

        content: str = field(default_factory=str)
        single_word: bool = False
        lstrip: bool = False
        rstrip: bool = False
        normalized: bool = True

        def __getstate__(self):
            return self.__dict__

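
# NOTE: the names below (`logger`, `_is_numpy`, `CharSpan`, `TokenSpan`) are referenced
# further down in this file (Trie.cut_text, BatchEncoding.convert_to_tensors,
# BatchEncoding.token_to_chars / word_to_tokens) but are not defined in this excerpt.
# The definitions here are a minimal sketch so the module stands alone; the upstream
# source may define them differently (e.g. with a library-specific logger).
import logging

logger = logging.getLogger(__name__)


def _is_numpy(x):
    """Returns True if `x` is a numpy array (used to skip re-converting values)."""
    return isinstance(x, np.ndarray)


class CharSpan(NamedTuple):
    """Character span in the original string: [start, end)."""

    start: int
    end: int


class TokenSpan(NamedTuple):
    """Token span in an encoded sequence: [start, end)."""

    start: int
    end: int
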

TOKENIZER_MAPPING_NAMES = OrderedDict(
    [
        (
            "albert",
            (
                "AlbertTokenizer" if is_sentencepiece_available() else None,
                "AlbertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("bart", ("BartTokenizer", "BartTokenizerFast")),
        (
            "barthez",
            (
                "BarthezTokenizer" if is_sentencepiece_available() else None,
                "BarthezTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("bartpho", ("BartphoTokenizer", None)),
        ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
        ("bert-japanese", ("BertJapaneseTokenizer", None)),
        ("bertweet", ("BertweetTokenizer", None)),
        (
            "big_bird",
            (
                "BigBirdTokenizer" if is_sentencepiece_available() else None,
                "BigBirdTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
        ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
        ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
        ("byt5", ("ByT5Tokenizer", None)),
        (
            "camembert",
            (
                "CamembertTokenizer" if is_sentencepiece_available() else None,
                "CamembertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("canine", ("CanineTokenizer", None)),
        (
            "clip",
            (
                "CLIPTokenizer",
                "CLIPTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        (
            "cpm",
            (
                "CpmTokenizer" if is_sentencepiece_available() else None,
                "CpmTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("ctrl", ("CTRLTokenizer", None)),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "deberta-v2",
            (
                "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
                "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
        (
            "dpr",
            (
                "DPRQuestionEncoderTokenizer",
                "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
        ("flaubert", ("FlaubertTokenizer", None)),
        ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
        ("fsmt", ("FSMTTokenizer", None)),
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
        ("hubert", ("Wav2Vec2CTCTokenizer", None)),
        ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
        ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
        ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
        (
            "longt5",
            (
                "T5Tokenizer" if is_sentencepiece_available() else None,
                "T5TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("luke", ("LukeTokenizer", None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        (
            "mbart",
            (
                "MBartTokenizer" if is_sentencepiece_available() else None,
                "MBartTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "mbart50",
            (
                "MBart50Tokenizer" if is_sentencepiece_available() else None,
                "MBart50TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
        ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
        ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
        (
            "mt5",
            (
                "MT5Tokenizer" if is_sentencepiece_available() else None,
                "MT5TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "nystromformer",
            (
                "AlbertTokenizer" if is_sentencepiece_available() else None,
                "AlbertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", None)),
        (
            "pegasus",
            (
                "PegasusTokenizer" if is_sentencepiece_available() else None,
                "PegasusTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "perceiver",
            (
                "PerceiverTokenizer",
                None,
            ),
        ),
        ("phobert", ("PhobertTokenizer", None)),
        ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
        ("prophetnet", ("ProphetNetTokenizer", None)),
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("rag", ("RagTokenizer", None)),
        ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
        (
            "reformer",
            (
                "ReformerTokenizer" if is_sentencepiece_available() else None,
                "ReformerTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "rembert",
            (
                "RemBertTokenizer" if is_sentencepiece_available() else None,
                "RemBertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        (
            "squeezebert",
            ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
        ),
        (
            "t5",
            (
                "T5Tokenizer" if is_sentencepiece_available() else None,
                "T5TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("tapas", ("TapasTokenizer", None)),
        ("tapex", ("TapexTokenizer", None)),
        ("transfo-xl", ("TransfoXLTokenizer", None)),
        ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
        (
            "xglm",
            (
                "XGLMTokenizer" if is_sentencepiece_available() else None,
                "XGLMTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("xlm", ("XLMTokenizer", None)),
        ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
        (
            "xlm-roberta",
            (
                "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "xlnet",
            (
                "XLNetTokenizer" if is_sentencepiece_available() else None,
                "XLNetTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "yoso",
            (
                "AlbertTokenizer" if is_sentencepiece_available() else None,
                "AlbertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
    ]
)
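
# The values above are (slow_tokenizer_class_name, fast_tokenizer_class_name) pairs. An entry
# silently degrades to None when its optional dependency is missing: "sentencepiece" for many
# slow tokenizers, "tokenizers" for every fast one. For example, with neither library installed,
# TOKENIZER_MAPPING_NAMES["albert"] is (None, None) while TOKENIZER_MAPPING_NAMES["bert"] still
# exposes the pure-Python "BertTokenizer".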

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
    [
        ("openai-gpt", "openai"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-text", "data2vec"),
        ("data2vec-vision", "data2vec"),
    ]
)


def model_type_to_module_name(key):
    """Converts a config key to the corresponding module."""
    # Special treatment
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]

    return key.replace("-", "_")
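
# Example (illustrative): the special-case table takes precedence, otherwise dashes become
# underscores:
#   model_type_to_module_name("data2vec-text")  -> "data2vec"
#   model_type_to_module_name("xlm-roberta")    -> "xlm_roberta"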


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        if hasattr(self._modules[module_name], value):
            return getattr(self._modules[module_name], value)

        # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
        # object at the top level.
        transformers_module = importlib.import_module("transformers")
        return getattr(transformers_module, value)

    def keys(self):
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self):
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

    def __iter__(self):
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value
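
# Illustrative sketch of how this lazy mapping is meant to be used. Building the mapping is
# cheap because `transformers.models.<module>` is only imported when a key is first accessed
# (assuming the `transformers` package is installed; `MyConfig` below is a hypothetical
# user-defined class):
#
#     CONFIG_MAPPING = _LazyConfigMapping({"bert": "BertConfig"})
#     config_cls = CONFIG_MAPPING["bert"]            # triggers the import on first lookup
#     CONFIG_MAPPING.register("my-model", MyConfig)  # extra entries bypass the lazy import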


class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass.
    Loose reference: https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self):
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.

        This function is idempotent: adding the same word twice leaves the trie unchanged.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return
        ref = self.data
        for char in word:
            ref[char] = char in ref and ref[char] or {}
            ref = ref[char]
        ref[""] = 1

    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. Output is the original string split along the
        boundaries of the words found.

        This trie will match the longest possible word first!

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".

        # States are going to capture every possible start (indexes as above)
        # as keys, and have as values, a pointer to the position in the trie
        # where we're at. This is a partial match for now.
        # This enables us to keep track of multiple matches while we're iterating
        # the string.
        # If the trie contains "blowing" and "lower", and we encounter the
        # string "blower", we need to split into ["b", "lower"].
        # This is where we need to keep track of multiple possible starts.
        states = OrderedDict()

        # This will contain every index where we need
        # to cut.
        # We force to cut at offset 0 and len(text) (added later)
        offsets = [0]

        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        skip = 0
        # Main loop, giving this algorithm O(n) complexity
        for current, current_char in enumerate(text):
            if skip and current < skip:
                # Prevents the lookahead from matching twice,
                # like extra_id_100 and id_100
                continue

            # This will track every state
            # that stops matching; we need to stop tracking them.
            # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
            # fail on "b", we need to remove 0 from the valid states.
            to_remove = set()
            # Whenever we find a match, we need to drop everything;
            # this is a greedy algorithm, it will match on the first found token
            reset = False

            # In this case, we already have partial matches (but unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # This is a final match, we need to reset and
                    # store the results in `offsets`.

                    # Lookahead to match longest first
                    # Important in case of extra_id_1 vs extra_id_100
                    # Here we are also actively looking for other earlier partial
                    # matches
                    # "[CLS]", "L", we need to match CLS even if L is special
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            # This partial match is earlier, the trie pointer
                            # was already updated, so index is + 1
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # Here lookstart == start and
                            #      looktrie_pointer == trie_pointer
                            # It wasn't updated yet so indices are current ones
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                        # End lookahead

                        # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The current character being looked at has a match within the trie
                    # update the pointer (it will be stored back into states later).
                    trie_pointer = trie_pointer[current_char]

                    # Storing back the new pointer into the states.
                    # Partial matches got longer by one.
                    states[start] = trie_pointer
                else:
                    # The new character has no match in the trie, we need
                    # to stop keeping track of this partial match.
                    # We can't do it directly within the loop because of how
                    # Python iteration works
                    to_remove.add(start)

            # Either clearing the full start (we found a real match)
            # Or clearing only the partial matches that didn't work.
            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]

            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # The longest cut is always the one with the lowest start, so it is
                # the first item; we can break.
                break

        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens
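
# Illustrative example of the overlapping-match case described in `Trie.split` above:
#
#     trie = Trie()
#     trie.add("blowing")
#     trie.add("lower")
#     trie.split("blower")   # -> ["b", "lower"]
#
# The partial match started at "b" dies at "e", while the one started at "l" completes,
# so the text is cut around "lower".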


from enum import Enum


class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    PADDLE = "paddle"
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"


class BatchEncoding(UserDict):
    """
    Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`],
    [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and
    [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc).

    This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes
    utility methods to map from word/character space to token space.

    Args:
        data (`dict`):
            Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods
            ('input_ids', 'attention_mask', etc.).
        encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*):
            If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character
            space to token space, the `tokenizers.Encoding` instance or list of instances (for batches) holds this
            information.
        tensor_type (`Union[None, str, TensorType]`, *optional*):
            You can give a tensor_type here to convert the lists of integers to tensors (NumPy arrays or Paddle
            tensors) at initialization.
        prepend_batch_axis (`bool`, *optional*, defaults to `False`):
            Whether or not to add a batch axis when converting to tensors (see `tensor_type` above).
        n_sequences (`Optional[int]`, *optional*):
            The number of sequences used to generate each sample from the batch encoded in this [`BatchEncoding`].
    """

    def __init__(
        self,
        data=None,
        encoding=None,
        tensor_type=None,
        prepend_batch_axis: bool = False,
        n_sequences=None,
    ):
        super().__init__(data)

        # if isinstance(encoding, EncodingFast):
        #    encoding = [encoding]

        self._encodings = encoding

        if n_sequences is None and encoding is not None and len(encoding):
            n_sequences = encoding[0].n_sequences

        self._n_sequences = n_sequences

        self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)

    @property
    def n_sequences(self) -> Optional[int]:
        """
        `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this
        [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of
        sentences)
        """
        return self._n_sequences

    @property
    def is_fast(self) -> bool:
        """
        `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`]
        or not.
        """
        return self._encodings is not None

    # def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:

    def __getitem__(self, item):
        """
        If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask',
        etc.).

        If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`.
        """
        if isinstance(item, str):
            return self.data[item]
        elif self._encodings is not None:
            return self._encodings[item]
        else:
            raise KeyError(
                "Indexing with integers (to access backend Encoding for a given batch index) "
                "is not available when using Python based tokenizers"
            )

    def __getattr__(self, item: str):
        try:
            return self.data[item]
        except KeyError:
            raise AttributeError

    def __getstate__(self):
        return {"data": self.data, "encodings": self._encodings}

    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]

        if "encodings" in state:
            self._encodings = state["encodings"]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    # After this point:
    # Extended properties and methods only available for fast (Rust-based) tokenizers
    # provided by HuggingFace tokenizers library.

    @property
    def encodings(self):
        """
        `Optional[List[tokenizers.Encoding]]`: The list of all encodings from the tokenization process. Returns `None`
        if the input was tokenized through Python (i.e., not a fast) tokenizer.
        """
        return self._encodings

    def tokens(self, batch_index=0):
        """
        Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to
        integer indices) at a given batch index (only works for the output of a fast tokenizer).

        Args:
            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.

        Returns:
            `List[str]`: The list of tokens at that index.
        """
        if not self._encodings:
            raise ValueError("tokens() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].tokens

    def sequence_ids(self, batch_index=0):
        """
        Return a list mapping the tokens to the id of their original sentences:

            - `None` for special tokens added around or between sequences,
            - `0` for tokens corresponding to words in the first sequence,
            - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly
              encoded.

        Args:
            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.

        Returns:
            `List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added
            by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding
            sequence.
        """
        if not self._encodings:
            raise ValueError("sequence_ids() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].sequence_ids

    def words(self, batch_index=0):
        """
        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.

        Args:
            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.

        Returns:
            `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
            tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
            (several tokens will be mapped to the same word index if they are parts of that word).
        """
        if not self._encodings:
            raise ValueError("words() is not available when using Python-based tokenizers")
        warnings.warn(
            "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
            "but more self-explanatory `BatchEncoding.word_ids()` property.",
            FutureWarning,
        )
        return self.word_ids(batch_index)

    def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
        """
        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.

        Args:
            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.

        Returns:
            `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
            tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
            (several tokens will be mapped to the same word index if they are parts of that word).
        """
        if not self._encodings:
            raise ValueError("word_ids() is not available when using Python-based tokenizers")
        return self._encodings[batch_index].word_ids

    def token_to_sequence(self, batch_or_token_index, token_index=None):
        """
        Get the index of the sequence represented by the given token. In the general use case, this method returns `0`
        for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair.

        Can be called as:

        - `self.token_to_sequence(token_index)` if batch size is 1
        - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1

        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
        words are defined by the user). In this case it allows to easily associate encoded tokens with provided
        tokenized words.

        Args:
            batch_or_token_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the token in the sequence.
            token_index (`int`, *optional*):
                If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
                sequence.

        Returns:
            `int`: Index of the sequence in the input (pair of) sequences.
        """

        if not self._encodings:
            raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if token_index < 0:
            token_index = self._seq_len + token_index
        return self._encodings[batch_index].token_to_sequence(token_index)

    def token_to_word(self, batch_or_token_index, token_index=None):
        """
        Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch.

        Can be called as:

        - `self.token_to_word(token_index)` if batch size is 1
        - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1

        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
        words are defined by the user). In this case it allows to easily associate encoded tokens with provided
        tokenized words.

        Args:
            batch_or_token_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the token in the sequence.
            token_index (`int`, *optional*):
                If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
                sequence.

        Returns:
            `int`: Index of the word in the input sequence.
        """

        if not self._encodings:
            raise ValueError("token_to_word() is not available when using Python based tokenizers")
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if token_index < 0:
            token_index = self._seq_len + token_index
        return self._encodings[batch_index].token_to_word(token_index)

    def word_to_tokens(self, batch_or_word_index, word_index=None, sequence_index=0):
        """
        Get the encoded token span corresponding to a word in a sequence of the batch.

        Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with:

        - **start** -- Index of the first token.
        - **end** -- Index of the token following the last token.

        Can be called as:

        - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1
        - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to
          1

        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.

        Args:
            batch_or_word_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the word in the sequence.
            word_index (`int`, *optional*):
                If a batch index is provided in *batch_or_word_index*, this can be the index of the word in the
                sequence.
            sequence_index (`int`, *optional*, defaults to 0):
                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
                (0 or 1) the provided word index belongs to.

        Returns:
            Optional [`~tokenization_utils_base.TokenSpan`]: Span of tokens in the encoded sequence. Returns `None` if
            no tokens correspond to the word.
        """

        if not self._encodings:
            raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
        if word_index is not None:
            batch_index = batch_or_word_index
        else:
            batch_index = 0
            word_index = batch_or_word_index
        if batch_index < 0:
            batch_index = self._batch_size + batch_index
        if word_index < 0:
            word_index = self._seq_len + word_index
        span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
        return TokenSpan(*span) if span is not None else None

    def token_to_chars(self, batch_or_token_index: int, token_index=None):
        """
        Get the character span corresponding to an encoded token in a sequence of the batch.

        Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with:

        - **start** -- Index of the first character in the original string associated to the token.
        - **end** -- Index of the character following the last character in the original string associated to the
          token.

        Can be called as:

        - `self.token_to_chars(token_index)` if batch size is 1
        - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1

        Args:
            batch_or_token_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the token in the sequence.
            token_index (`int`, *optional*):
                If a batch index is provided in *batch_or_token_index*, this can be the index of the token or tokens in
                the sequence.

        Returns:
            [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or `None` if the token
            (e.g. <s>, </s>) doesn't correspond to any characters in the original string.
        """

        if not self._encodings:
            raise ValueError("token_to_chars() is not available when using Python based tokenizers")
        if token_index is not None:
            batch_index = batch_or_token_index
        else:
            batch_index = 0
            token_index = batch_or_token_index
        span_indices = self._encodings[batch_index].token_to_chars(token_index)

        return CharSpan(*span_indices) if span_indices is not None else None

    def char_to_token(
        self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
    ) -> int:
        """
        Get the index of the token in the encoded output comprising a character in the original string for a sequence
        of the batch.

        Can be called as:

        - `self.char_to_token(char_index)` if batch size is 1
        - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1

        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.

        Args:
            batch_or_char_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the character in the sequence.
            char_index (`int`, *optional*):
                If a batch index is provided in *batch_or_char_index*, this can be the index of the character in the
                sequence.
            sequence_index (`int`, *optional*, defaults to 0):
                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
                (0 or 1) the provided character index belongs to.


        Returns:
            `int`: Index of the token.
        """

        if not self._encodings:
            raise ValueError("char_to_token() is not available when using Python based tokenizers")
        if char_index is not None:
            batch_index = batch_or_char_index
        else:
            batch_index = 0
            char_index = batch_or_char_index
        return self._encodings[batch_index].char_to_token(char_index, sequence_index)

    def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0):
        """
        Get the character span in the original string corresponding to a given word in a sequence of the batch.

        Character spans are returned as a CharSpan NamedTuple with:

        - start: index of the first character in the original string
        - end: index of the character following the last character in the original string

        Can be called as:

        - `self.word_to_chars(word_index)` if batch size is 1
        - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1

        Args:
            batch_or_word_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the word in the sequence.
            word_index (`int`, *optional*):
                If a batch index is provided in *batch_or_word_index*, this can be the index of the word in the
                sequence.
            sequence_index (`int`, *optional*, defaults to 0):
                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
                (0 or 1) the provided word index belongs to.

        Returns:
            `CharSpan` or `List[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan
            are NamedTuple with:

                - start: index of the first character associated to the token in the original string
                - end: index of the character following the last character associated to the token in the original
                  string
        """

        if not self._encodings:
            raise ValueError("word_to_chars() is not available when using Python based tokenizers")
        if word_index is not None:
            batch_index = batch_or_word_index
        else:
            batch_index = 0
            word_index = batch_or_word_index
        return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))

    def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
        """
        Get the word in the original string corresponding to a character in the original string of a sequence of the
        batch.

        Can be called as:

        - `self.char_to_word(char_index)` if batch size is 1
        - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1

        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
        words.

        Args:
            batch_or_char_index (`int`):
                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
                the character in the original string.
            char_index (`int`, *optional*):
                If a batch index is provided in *batch_or_char_index*, this can be the index of the character in the
                original string.
            sequence_index (`int`, *optional*, defaults to 0):
                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
                (0 or 1) the provided character index belongs to.


        Returns:
            `int` or `List[int]`: Index or indices of the corresponding word(s) in the input sequence.
        """

        if not self._encodings:
            raise ValueError("char_to_word() is not available when using Python based tokenizers")
        if char_index is not None:
            batch_index = batch_or_char_index
        else:
            batch_index = 0
            char_index = batch_or_char_index
        return self._encodings[batch_index].char_to_word(char_index, sequence_index)

    def convert_to_tensors(self, tensor_type=None, prepend_batch_axis: bool = False):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
            prepend_batch_axis (`bool`, *optional*, defaults to `False`):
                Whether or not to add the batch dimension during the conversion.
        """
        if tensor_type is None:
            return self

        # Get a function reference for the correct framework
        if tensor_type == "paddle":
            import paddle

            as_tensor = paddle.to_tensor
            is_tensor = paddle.is_tensor
        else:
            as_tensor = np.asarray
            is_tensor = _is_numpy
        # (mfuntowicz: This code is unreachable)
        # else:
        #     raise ImportError(
        #         f"Unable to convert output to tensors format {tensor_type}"
        #     )

        # Do the tensor conversion in batch
        for key, value in self.items():
            try:
                if prepend_batch_axis:
                    value = [value]

                if not is_tensor(value):
                    tensor = as_tensor(value)

                    # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
                    # # at-least2d
                    # if tensor.ndim > 2:
                    #     tensor = tensor.squeeze(0)
                    # elif tensor.ndim < 2:
                    #     tensor = tensor[None, :]

                    self[key] = tensor
            except:  # noqa E722
                if key == "overflowing_tokens":
                    raise ValueError(
                        "Unable to create tensor returning overflowing tokens of different lengths. "
                        "Please see if a fast version of this tokenizer is available to have this feature available."
                    )
                raise ValueError(
                    "Unable to create tensor, you should probably activate truncation and/or padding "
                    "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
                )

        return self

1164

1165
class TruncationStrategy(ExplicitEnum):
1166
    """
1167
    Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in
1168
    an IDE.
1169
    """
1170

1171
    ONLY_FIRST = "only_first"
1172
    ONLY_SECOND = "only_second"
1173
    LONGEST_FIRST = "longest_first"
1174
    DO_NOT_TRUNCATE = "do_not_truncate"
1175

1176

1177
class PaddingStrategy(ExplicitEnum):
1178
    """
1179
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
1180
    IDE.
1181
    """
1182

1183
    LONGEST = "longest"
1184
    MAX_LENGTH = "max_length"
1185
    DO_NOT_PAD = "do_not_pad"
1186

1187

1188
class SpecialTokensMixin:
1189
    """
1190
    A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to
1191
    special tokens. In particular, this class hold the attributes which can be used to directly access these special
1192
    tokens in a model-independent manner and allow to set and update the special tokens.
1193

1194
    Args:
1195
        bos_token (`str` or `tokenizers.AddedToken`, *optional*):
1196
            A special token representing the beginning of a sentence.
1197
        eos_token (`str` or `tokenizers.AddedToken`, *optional*):
1198
            A special token representing the end of a sentence.
1199
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
1200
            A special token representing an out-of-vocabulary token.
1201
        sep_token (`str` or `tokenizers.AddedToken`, *optional*):
1202
            A special token separating two different sentences in the same input (used by BERT for instance).
1203
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
1204
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
1205
            attention mechanisms or loss computation.
1206
        cls_token (`str` or `tokenizers.AddedToken`, *optional*):
1207
            A special token representing the class of the input (used by BERT for instance).
1208
        mask_token (`str` or `tokenizers.AddedToken`, *optional*):
1209
            A special token representing a masked token (used by masked-language modeling pretraining objectives, like
1210
            BERT).
1211
        additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
1212
            A tuple or a list of additional special tokens.
    """

    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "sep_token",
        "pad_token",
        "cls_token",
        "mask_token",
        "additional_special_tokens",
    ]

    def __init__(self, verbose=True, **kwargs):
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._pad_token_type_id = 0
        self._additional_special_tokens = []
        self.verbose = verbose
        self.added_tokens_encoder: Dict[str, int] = {}
        self.added_tokens_decoder: Dict[int, str] = {}
        self.unique_no_split_tokens: List[str] = []
        self.tokens_trie = Trie()

        self._decode_use_source_tokenizer = False

        # We directly set the hidden value to allow initialization with special tokens
        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
        # TODO clean this up at some point (probably by switching to fast tokenizers)
        for key, value in kwargs.items():
            if value is None:
                continue
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == "additional_special_tokens":
                    assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
                    assert all(
                        isinstance(t, (str, AddedToken)) for t in value
                    ), "One of the tokens is not a string or an AddedToken"
                    setattr(self, key, value)
                elif isinstance(value, (str, AddedToken)):
                    setattr(self, key, value)
                else:
                    raise TypeError(f"special token {key} has to be either str or AddedToken but got: {type(value)}")

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) to a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            `int` or `List[int]`: The token id or list of token ids.
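
        Example (an illustrative sketch; `BertTokenizer` and the resulting ids are assumptions that
        depend on the loaded vocabulary):

        ```python
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.convert_tokens_to_ids("[CLS]")  # a single int id
        tokenizer.convert_tokens_to_ids(["hello", "world"])  # a list of int ids
        ```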
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token is None:
            return None

        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
1294
        if token.startswith("<extra_id_"):
1295
            match = re.match(r"<extra_id_(\d+)>", token)
1296
            num = int(match.group(1))
1297
            return self.vocab_size - num - 1
1298
        return self.sp_model.piece_to_id(token)
1299

1300
    def sanitize_special_tokens(self) -> int:
1301
        """
1302
        Make sure that all the special tokens attributes of the tokenizer (`tokenizer.mask_token`,
1303
        `tokenizer.cls_token`, etc.) are in the vocabulary.
1304

1305
        Add the missing ones to the vocabulary if needed.
1306

1307
        Return:
1308
            `int`: The number of tokens added in the vocabulary during the operation.
1309
        """
1310
        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
1311

1312
    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
1313
        """
1314
        Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
1315
        special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
1316
        current vocabulary).
1317

1318
        Note,None When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
1319
        matrix of the model so that its embedding matrix matches the tokenizer.
1320

1321
        In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
1322

1323
        Using `add_special_tokens` will ensure your special tokens can be used in several ways:
1324

1325
        - Special tokens are carefully handled by the tokenizer (they are never split).
1326
        - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This
1327
          makes it easy to develop model-agnostic training and fine-tuning scripts.
1328

1329
        When possible, special tokens are already registered for provided pretrained models (for instance
1330
        [`BertTokenizer`] `cls_token` is already registered to be :obj*'[CLS]'* and XLM's one is also registered to be
1331
        `'</s>'`).
1332

1333
        Args:
1334
            special_tokens_dict (dictionary *str* to *str* or `tokenizers.AddedToken`):
1335
                Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`,
1336
                `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`].
1337

1338
                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
1339
                assign the index of the `unk_token` to them).
1340

1341
        Returns:
1342
            `int`: Number of tokens added to the vocabulary.
1343

1344
        Examples:
1345

1346
        ```python
1347
        # Let's see how to add a new classification token to GPT-2
1348
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
1349
        model = GPT2Model.from_pretrained("gpt2")
1350

1351
        special_tokens_dict = {"cls_token": "<CLS>"}
1352

1353
        num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
1354
        print("We have added", num_added_toks, "tokens")
1355
        # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
1356
        model.resize_token_embeddings(len(tokenizer))
1357

1358
        assert tokenizer.cls_token == "<CLS>"
1359
        ```"""
1360
        if not special_tokens_dict:
1361
            return 0
1362

1363
        added_tokens = 0
1364
        for key, value in special_tokens_dict.items():
1365
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
1366

1367
            if self.verbose:
1368
                # logger.info(f"Assigning {value} to the {key} key of the tokenizer")
1369
                print(f"Assigning {value} to the {key} key of the tokenizer")
1370
            setattr(self, key, value)
1371

1372
            if key == "additional_special_tokens":
1373
                assert isinstance(value, (list, tuple)) and all(
1374
                    isinstance(t, (str, AddedToken)) for t in value
1375
                ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
1376
                added_tokens += self.add_tokens(value, special_tokens=True)
1377
            else:
1378
                assert isinstance(
1379
                    value, (str, AddedToken)
1380
                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
1381
                added_tokens += self.add_tokens([value], special_tokens=True)
1382

1383
        return added_tokens

    def add_tokens(
        self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False
    ) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
        it with indices starting from the length of the current vocabulary.

        Note: When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
        matrix of the model so that its embedding matrix matches the tokenizer.

        In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.

        Args:
            new_tokens (`str`, `tokenizers.AddedToken` or a list of *str* or `tokenizers.AddedToken`):
                Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string
                token to let you personalize its behavior: whether this token should only match against a single word,
                whether this token should strip all potential whitespaces on the left side, whether this token should
                strip all potential whitespaces on the right side, etc.
            special_tokens (`bool`, *optional*, defaults to `False`):
                Can be used to specify if the token is a special token. This mostly changes the normalization behavior
                (special tokens like CLS or [MASK] are usually not lower-cased for instance).

                See details for `tokenizers.AddedToken` in HuggingFace tokenizers library.

        Returns:
            `int`: Number of tokens added to the vocabulary.

        Examples:

        ```python
        # Let's see how to increase the vocabulary of Bert model and tokenizer
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        model = BertModel.from_pretrained("bert-base-uncased")

        num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
        print("We have added", num_added_toks, "tokens")
        # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
        model.resize_token_embeddings(len(tokenizer))
        ```"""
        if not new_tokens:
            return 0

        if not isinstance(new_tokens, (list, tuple)):
            new_tokens = [new_tokens]

        return self._add_tokens(new_tokens, special_tokens=special_tokens)

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
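        # Only tokens that the current vocabulary does not know (i.e. that resolve to the
        # unk_token id) are queued for addition; their new ids continue from the current
        # tokenizer length.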
        new_tokens = [str(tok) for tok in new_tokens]

        tokens_to_add = []
        for token in new_tokens:
            if not isinstance(token, str):
                raise TypeError(f"Token {token} is not a string but a {type(token)}.")
            if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
                token = token.lower()
            if (
                token != self.unk_token
                and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
                and token not in tokens_to_add
            ):
                tokens_to_add.append(token)
                # if self.verbose:
            # logger.info(f"Adding {token} to the vocabulary")
            # print(f"Adding {token} to the vocabulary")

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        # Make sure we don't split on any special tokens (even if they were already in the vocab before e.g. for Albert)
        if special_tokens:
            if len(new_tokens) == 1:
                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0])
            else:
                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
        else:
            # Or on the newly added tokens
            if len(tokens_to_add) == 1:
                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
            else:
                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
        self._create_trie(self.unique_no_split_tokens)

        return len(tokens_to_add)

    def _create_trie(self, unique_no_split_tokens):
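        # Build a prefix trie over the no-split tokens so the tokenizer can locate and protect
        # them in the input text in a single pass (lower-cased when `do_lower_case` applies,
        # except for special tokens).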
        trie = Trie()
        for token in unique_no_split_tokens:
            if hasattr(self, "do_lower_case") and self.do_lower_case and token not in self.all_special_tokens:
                trie.add(token.lower())
            else:
                trie.add(token)
        self.tokens_trie = trie

    @property
    def bos_token(self) -> str:
        """
        `str`: Beginning of sentence token. Log an error if used while not having been set.
        """
        if self._bos_token is None and self.verbose:
            print("Using bos_token, but it is not set yet.")
            # logger.error("Using bos_token, but it is not set yet.")
            return None
        return str(self._bos_token)

    @property
    def eos_token(self) -> str:
        """
        `str`: End of sentence token. Log an error if used while not having been set.
        """
        if self._eos_token is None and self.verbose:
            # logger.error("Using eos_token, but it is not set yet.")
            print("Using eos_token, but it is not set yet.")
            return None
        return str(self._eos_token)

    @property
    def unk_token(self) -> str:
        """
        `str`: Unknown token. Log an error if used while not having been set.
        """
        if self._unk_token is None and self.verbose:
            print("Using unk_token, but it is not set yet.")
            # logger.error("Using unk_token, but it is not set yet.")
            return None
        return str(self._unk_token)

    @property
    def sep_token(self) -> str:
        """
        `str`: Separation token, to separate context and query in an input sequence. Log an error if used while not
        having been set.
        """
        if self._sep_token is None and self.verbose:
            print("Using sep_token, but it is not set yet.")
            # logger.error("Using sep_token, but it is not set yet.")
            return None
        return str(self._sep_token)

    @property
    def pad_token(self) -> str:
        """
        `str`: Padding token. Log an error if used while not having been set.
        """
        if self._pad_token is None and self.verbose:
            # logger.error("Using pad_token, but it is not set yet.")
            print("Using pad_token, but it is not set yet.")
            return None
        return str(self._pad_token)

    @property
    def cls_token(self) -> str:
        """
        `str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full
        depth of the model. Log an error if used while not having been set.
        """
        if self._cls_token is None and self.verbose:
            # logger.error("Using cls_token, but it is not set yet.")
            print("Using cls_token, but it is not set yet.")
            return None
        return str(self._cls_token)

    @property
    def mask_token(self) -> str:
        """
        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
        having been set.
        """
        if self._mask_token is None and self.verbose:
            # logger.error("Using mask_token, but it is not set yet.")
            print("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @property
    def additional_special_tokens(self) -> List[str]:
        """
        `List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been
        set.
        """
        if self._additional_special_tokens is None and self.verbose:
            # logger.error("Using additional_special_tokens, but it is not set yet.")
            print("Using additional_special_tokens, but it is not set yet.")
            return None
        return [str(tok) for tok in self._additional_special_tokens]

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

    @property
    def bos_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns `None` if the token has not
        been set.
        """
        if self._bos_token is None:
            return None
        return self.convert_tokens_to_ids(self.bos_token)

    @property
    def eos_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
        set.
        """
        if self._eos_token is None:
            return None
        return self.convert_tokens_to_ids(self.eos_token)

    @property
    def unk_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the unknown token in the vocabulary. Returns `None` if the token has not been set.
        """
        if self._unk_token is None:
            return None
        return self.convert_tokens_to_ids(self.unk_token)

    @property
    def sep_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input
        sequence. Returns `None` if the token has not been set.
        """
        if self._sep_token is None:
            return None
        return self.convert_tokens_to_ids(self.sep_token)

    @property
    def pad_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
        """
        if self._pad_token is None:
            return None
        return self.convert_tokens_to_ids(self.pad_token)

    @property
    def pad_token_type_id(self) -> int:
        """
        `int`: Id of the padding token type in the vocabulary.
        """
        return self._pad_token_type_id

    @property
    def cls_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input sequence
        leveraging self-attention along the full depth of the model.

        Returns `None` if the token has not been set.
        """
        if self._cls_token is None:
            return None
        return self.convert_tokens_to_ids(self.cls_token)

    @property
    def mask_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language
        modeling. Returns `None` if the token has not been set.
        """
        if self._mask_token is None:
            return None
        return self.convert_tokens_to_ids(self.mask_token)

    @property
    def additional_special_tokens_ids(self) -> List[int]:
        """
        `List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if used while not having
        been set.
        """
        return self.convert_tokens_to_ids(self.additional_special_tokens)

    @bos_token_id.setter
    def bos_token_id(self, value):
        self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None

    @eos_token_id.setter
    def eos_token_id(self, value):
        self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None

    @unk_token_id.setter
    def unk_token_id(self, value):
        self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None

    @sep_token_id.setter
    def sep_token_id(self, value):
        self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None

    @pad_token_id.setter
    def pad_token_id(self, value):
        self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None

    @cls_token_id.setter
    def cls_token_id(self, value):
        self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None

    @mask_token_id.setter
    def mask_token_id(self, value):
        self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None

    @additional_special_tokens_ids.setter
    def additional_special_tokens_ids(self, values):
        self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values]

    @property
    def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
        """
        `Dict[str, Union[str, List[str]]]`: A dictionary mapping special token class attributes (`cls_token`,
        `unk_token`, etc.) to their values (`'<unk>'`, `'<cls>'`, etc.).

        Convert potential tokens of `tokenizers.AddedToken` type to string.
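
        Example (illustrative; the exact mapping depends on the tokenizer, here a BERT-style one is
        assumed):

        ```python
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.special_tokens_map
        # {'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]',
        #  'cls_token': '[CLS]', 'mask_token': '[MASK]'}
        ```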
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = (
                    type(attr_value)(str(attr_value_sub) for attr_value_sub in attr_value)
                    if isinstance(attr_value, (list, tuple))
                    else str(attr_value)
                )
        return set_attr

    @property
    def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]:
        """
        `Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping
        special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`'<unk>'`, `'<cls>'`, etc.).

        Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
        special tokens are tokenized.
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_tokens(self) -> List[str]:
        """
        `List[str]`: All the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.

        Convert tokens of `tokenizers.AddedToken` type to string.
        """
        all_toks = [str(s) for s in self.all_special_tokens_extended]
        return all_toks

    @property
    def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
        """
        `List[Union[str, tokenizers.AddedToken]]`: All the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class
        attributes.

        Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
        special tokens are tokenized.
        """
        all_toks = []
        set_attr = self.special_tokens_map_extended
        for attr_value in set_attr.values():
            all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
        all_toks = list(OrderedDict.fromkeys(all_toks))
        return all_toks

    @property
    def all_special_ids(self) -> List[int]:
        """
        `List[int]`: List the ids of the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
        """
        all_toks = self.all_special_tokens
        all_ids = self.convert_tokens_to_ids(all_toks)
        return all_ids