# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for XLNet model."""


import logging
import os
import unicodedata
from shutil import copyfile
from typing import List, Optional

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
        "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "xlnet-base-cased": None,
    "xlnet-large-cased": None,
}

SPIECE_UNDERLINE = "▁"

# Segments (not really needed)
SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

class XLNetTokenizer(PreTrainedTokenizer):
    """
    Constructs an XLNet tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`string`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to lowercase the input when tokenizing.
        remove_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to strip the text when tokenizing (removing excess spaces before and after the string).
        keep_accents (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to keep accents when tokenizing.
        bos_token (:obj:`string`, `optional`, defaults to "<s>"):
            The beginning of sequence token that was used during pre-training. Can be used as a sequence classifier
            token.

            .. note::

                When building a sequence using special tokens, this is not the token that is used for the beginning
                of sequence. The token used is the :obj:`cls_token`.
        eos_token (:obj:`string`, `optional`, defaults to "</s>"):
            The end of sequence token.

            .. note::

                When building a sequence using special tokens, this is not the token that is used for the end
                of sequence. The token used is the :obj:`sep_token`.
        unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        sep_token (:obj:`string`, `optional`, defaults to "<sep>"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is placed at the end
            of each segment of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "<cls>"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the last token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "<mask>"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<eop>", "<eod>"]`):
            Additional special tokens used by the tokenizer.

    Attributes:
        sp_model (:obj:`SentencePieceProcessor`):
            The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    padding_side = "left"

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        remove_space=True,
        keep_accents=False,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        cls_token="<cls>",
        mask_token="<mask>",
        additional_special_tokens=["<eop>", "<eod>"],
        **kwargs
    ):
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        self._pad_token_type_id = 3

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        return len(self.sp_model)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs
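    # Illustration (not from the original source): with remove_space=True, keep_accents=False and
    # do_lower_case=True, preprocess_text("  Héllo ``world''  ") collapses the whitespace,
    # normalizes the backtick quotes, strips the accent via NFKD and lowercases, yielding
    # 'hello "world"'.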

    def _tokenize(self, text, sample=False):
        """ Tokenize a string. """
        text = self.preprocess_text(text)

        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces
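    # Note on the loop above (illustrative, not from the original source): SentencePiece may emit
    # pieces such as "▁1928," where a trailing comma stays glued to digits; the branch re-encodes
    # the digit part so the output becomes something like ["▁1928", ","], with the comma as its
    # own piece and the leading underline adjusted to match the original piece.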

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) to a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string
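    # Illustration (not from the original source): convert_tokens_to_string(["▁Hello", "▁world", "!"])
    # joins the pieces, maps the SentencePiece underline back to a space and strips the edges,
    # returning "Hello world!".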

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        An XLNet sequence has the following format:

        - single sequence: ``X <sep> <cls>``
        - pair of sequences: ``A <sep> B <sep> <cls>``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return token_ids_0 + sep + cls
        return token_ids_0 + sep + token_ids_1 + sep + cls
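    # Illustration (the IDs are placeholders, not real vocabulary entries): with token_ids_0=[10, 11]
    # and token_ids_1=[20], this returns [10, 11, sep_id, 20, sep_id, cls_id], i.e. the
    # ``A <sep> B <sep> <cls>`` layout described above.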

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
        return ([0] * len(token_ids_0)) + [1, 1]
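    # Illustration (placeholder IDs): for token_ids_0=[10, 11] and token_ids_1=[20] the mask is
    # [0, 0, 1, 0, 1, 1] -- zeros for sequence tokens, ones at the positions of the <sep>, <sep>
    # and <cls> tokens appended by build_inputs_with_special_tokens.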

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        An XLNet sequence pair mask has the following format:

        ::

            0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2
            | first sequence    | second sequence     | CLS segment ID

        If token_ids_1 is None, only returns the first portion of the mask (0's).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls_segment_id = [2]

        if token_ids_1 is None:
            return len(token_ids_0 + sep) * [0] + cls_segment_id
        return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
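    # Illustration (placeholder IDs): for token_ids_0=[10, 11] and token_ids_1=[20] this returns
    # [0, 0, 0, 1, 1, 2] -- segment 0 for "A <sep>", segment 1 for "B <sep>", and 2 for the final
    # <cls> token, matching the SEG_ID_* constants defined at the top of the file.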

    def save_vocabulary(self, save_directory):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.

        Args:
            save_directory (:obj:`str`):
                The directory in which to save the vocabulary.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)