CSS-LM

tokenization_reformer.py 
180 lines · 6.6 KB
# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model Reformer."""


import logging
import os
from shutil import copyfile

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

SPIECE_UNDERLINE = "▁"
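# U+2581 ("LOWER ONE EIGHTH BLOCK"): the character SentencePiece uses to mark
# word boundaries / leading whitespace in its pieces.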


####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to pretrained vocabulary URLs for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
    }
}

####################################################
# Mapping from model shortcut names to the max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/reformer-crime-and-punishment": 524288,
}
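# 524288 = 2 ** 19 tokens: the maximum input length for this checkpoint,
# reflecting Reformer's focus on very long sequences.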


class ReformerTokenizer(PreTrainedTokenizer):
    """
        Constructs a Reformer tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .

        This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer`, which contains most of the methods. Users
        should refer to the superclass for more information regarding those methods.

        Args:
            vocab_file (:obj:`string`):
                `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
                contains the vocabulary necessary to instantiate a tokenizer.
            eos_token (:obj:`string`, `optional`, defaults to "</s>"):
                The end of sequence token.

                .. note::

                    When building a sequence using special tokens, this is not the token that is used for the end
                    of sequence. The token used is the :obj:`sep_token`.
            unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
                The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
                token instead.
            pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
                The token used for padding, for example when batching sequences of different lengths.
            additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
                Additional special tokens used by the tokenizer.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        additional_special_tokens=[],
        **kwargs
    ):
        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use ReformerTokenizer: "
                "https://github.com/google/sentencepiece "
                "(pip install sentencepiece)"
            )
            raise

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use ReformerTokenizer: "
                "https://github.com/google/sentencepiece "
                "(pip install sentencepiece)"
            )
            raise
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces
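    # The sampling branch calls SentencePiece's SampleEncodeAsPieces, which
    # draws one segmentation at random (subword regularization) from the
    # 64-best candidates with smoothing parameter alpha=0.1.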

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        if index < self.sp_model.get_piece_size():
            return self.sp_model.IdToPiece(index)
        # Guard against out-of-range indices instead of referencing an unbound
        # local: fall back to the unknown token.
        return self.unk_token

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) to a single string. """
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
            to a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
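
# Minimal usage sketch (assumes the "google/reformer-crime-and-punishment"
# checkpoint referenced in PRETRAINED_VOCAB_FILES_MAP above is available for
# download; only the generic PreTrainedTokenizer API is used):
#
#     tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
#     tokens = tokenizer.tokenize("Hello, how are you?")  # list of SentencePiece pieces
#     ids = tokenizer.encode("Hello, how are you?")       # list of piece ids
#     text = tokenizer.decode(ids)                        # back to a string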