# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model Reformer."""


import logging
import os
from shutil import copyfile

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

SPIECE_UNDERLINE = "▁"


####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/reformer-crime-and-punishment": 524288,
}


class ReformerTokenizer(PreTrainedTokenizer):
    """
    Constructs a Reformer tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer`, which contains most of the main
    methods. Users should refer to the superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`string`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (:obj:`string`, `optional`, defaults to "</s>"):
            The end of sequence token.

            .. note::

                When building a sequence using special tokens, this is not the token that is used for the end
                of sequence. The token used is the :obj:`sep_token`.
        unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to
            this token instead.
        pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`[]`):
            Additional special tokens used by the tokenizer.
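
    Example (illustrative sketch: it uses only helpers inherited from
    :class:`~transformers.PreTrainedTokenizer` together with the checkpoint name defined in
    ``PRETRAINED_VOCAB_FILES_MAP`` above)::

        tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
        pieces = tokenizer.tokenize("Crime and Punishment")  # SentencePiece sub-word pieces
        ids = tokenizer.encode("Crime and Punishment")       # list of vocabulary ids
        text = tokenizer.decode(ids)                         # back to a plain string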
79"""

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        additional_special_tokens=[],
        **kwargs
    ):
        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use ReformerTokenizer: "
                "https://github.com/google/sentencepiece "
                "(pip install sentencepiece)"
            )
            raise

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

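    # The loaded SentencePieceProcessor cannot be pickled directly, so __getstate__ drops it
    # and __setstate__ reloads it from ``self.vocab_file``.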
    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use ReformerTokenizer: "
                "https://github.com/google/sentencepiece "
                "(pip install sentencepiece)"
            )
            raise
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words. """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """ Converts an index (integer) to a token (str) using the vocab. """
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (strings) into a single string. """
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
        to a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

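        # Copy the original SentencePiece model file only if the destination differs from the source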
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)