CSS-LM

tokenization_t5.py 
208 lines · 8.2 KB
# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model T5."""


import logging
import os
import re
from shutil import copyfile

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

SPIECE_UNDERLINE = "▁"

####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "t5-small": 512,
    "t5-base": 512,
    "t5-large": 512,
    "t5-3b": 512,
    "t5-11b": 512,
}


class T5Tokenizer(PreTrainedTokenizer):
    """
        Constructs a T5 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .

        This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main
        methods. Users should refer to this superclass for more information regarding those methods.

        Args:
            vocab_file (:obj:`string`):
                `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
                contains the vocabulary necessary to instantiate a tokenizer.
            eos_token (:obj:`string`, `optional`, defaults to "</s>"):
                The end of sequence token.

                .. note::

                    When building a sequence using special tokens, this is not the token that is used for the end
                    of sequence. The token used is the :obj:`sep_token`.
            unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
                The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
                token instead.
            pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
                The token used for padding, for example when batching sequences of different lengths.
            extra_ids (:obj:`int`, `optional`, defaults to :obj:`100`):
                Number of extra ids to add to the end of the vocabulary for use as sentinels.
                These tokens are accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1.
                Extra tokens are indexed from the end of the vocabulary towards the beginning ("<extra_id_0>" is the
                last token in the vocabulary, as in T5 preprocessing, see
                https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117).
            additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
                Additional special tokens used by the tokenizer.
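
        Example (a minimal sketch; it assumes the ``t5-small`` shortcut from
        ``PRETRAINED_VOCAB_FILES_MAP`` and the encoding helpers inherited from
        :class:`~transformers.PreTrainedTokenizer`)::

            tokenizer = T5Tokenizer.from_pretrained("t5-small")
            tokens = tokenizer.tokenize("Hello world")     # SentencePiece pieces, e.g. ['▁Hello', '▁world']
            ids = tokenizer.convert_tokens_to_ids(tokens)  # ids from spiece.model
            # "<extra_id_0>" is the last id in the vocabulary (vocab_size - 1)
            sentinel_id = tokenizer.convert_tokens_to_ids("<extra_id_0>")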
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        **kwargs
    ):
        # Add extra_ids to the special token list
        if extra_ids > 0:
            if additional_special_tokens is None:
                additional_special_tokens = []
            additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use T5Tokenizer "
                "(https://github.com/google/sentencepiece). "
                "Install it with: pip install sentencepiece"
            )
            raise

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
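        # The reported size covers the SentencePiece pieces plus the extra_ids
        # sentinel tokens appended at the top of the id space.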
        return self.sp_model.get_piece_size() + self._extra_ids

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
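        # The SentencePiece processor is excluded from the pickled state and
        # reloaded from self.vocab_file in __setstate__.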
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use T5Tokenizer "
                "(https://github.com/google/sentencepiece). "
                "Install it with: pip install sentencepiece"
            )
            raise
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
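            # Subword-regularization sampling: SentencePiece draws from the 64 best
            # segmentations with smoothing parameter alpha=0.1.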
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        if token.startswith("<extra_id_"):
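            # Sentinel tokens occupy the top of the id space: "<extra_id_0>" maps to
            # vocab_size - 1, "<extra_id_1>" to vocab_size - 2, and so on.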
            match = re.match(r"<extra_id_(\d+)>", token)
            num = int(match.group(1))
            return self.vocab_size - num - 1
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
        return token

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) to a single string. """
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
            to a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
