# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model T5."""


import logging
import os
import re
from shutil import copyfile

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

SPIECE_UNDERLINE = "▁"

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "t5-small": 512,
    "t5-base": 512,
    "t5-large": 512,
    "t5-3b": 512,
    "t5-11b": 512,
}


class T5Tokenizer(PreTrainedTokenizer):
    """
    Constructs a T5 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`string`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (:obj:`string`, `optional`, defaults to "</s>"):
            The end of sequence token.

            .. note::

                When building a sequence using special tokens, this is not the token that is used for the end
                of sequence. The token used is the :obj:`sep_token`.
        unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
            The token used for padding, for example when batching sequences of different lengths.
        extra_ids (:obj:`int`, `optional`, defaults to :obj:`100`):
            Number of extra ids added to the end of the vocabulary, for use as sentinels.
            These tokens are accessible as "<extra_id_{%d}>", where "{%d}" is a number between 0 and extra_ids - 1.
            Extra tokens are indexed from the end of the vocabulary up to the beginning ("<extra_id_0>" is the last
            token in the vocabulary, as in T5 preprocessing; see
            https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117).
        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
            Additional special tokens used by the tokenizer.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        **kwargs
    ):
        # Add extra_ids to the special token list
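        # ("<extra_id_0>" ... "<extra_id_{extra_ids - 1}>"); registering the sentinels as
        # additional special tokens lets the tokenizer treat each one as an atomic token
        # instead of passing it through SentencePiece.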
        if extra_ids > 0:
            if additional_special_tokens is None:
                additional_special_tokens = []
            additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use T5Tokenizer: "
                "https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )
            raise

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
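        # Size of the SentencePiece vocabulary plus the extra sentinel ids appended at the end.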
        return self.sp_model.get_piece_size() + self._extra_ids

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

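    # To keep instances picklable, the SentencePiece processor is dropped from the state
    # on __getstate__ and reloaded from `self.vocab_file` on __setstate__.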
    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use T5Tokenizer: https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )
            raise
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """ Take a string as input and return a list of strings (tokens) for words/sub-words.
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _convert_token_to_id(self, token):
        """ Converts a token (str) into an id using the vocab. """
        if token.startswith("<extra_id_"):
            match = re.match(r"<extra_id_(\d+)>", token)
            num = int(match.group(1))
            return self.vocab_size - num - 1
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
        return token

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) into a single string. """
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy the original file) and special tokens file
        to a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
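
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes network
# access so that `from_pretrained` can fetch the "t5-small" SentencePiece file
# listed in PRETRAINED_VOCAB_FILES_MAP; with a local `spiece.model` you could
# instead construct the tokenizer directly via T5Tokenizer(vocab_file=...).
# Because this file uses a relative import, run it as a module (python -m ...)
# rather than as a standalone script.
if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    pieces = tokenizer.tokenize("Translate English to German: Hello world")
    ids = tokenizer.convert_tokens_to_ids(pieces)
    print(pieces)                                      # SentencePiece pieces, e.g. ['▁Translate', ...]
    print(ids)                                         # integer ids from spiece.model
    print(tokenizer.convert_tokens_to_string(pieces))  # round-trips back to the input text

    # Sentinel tokens sit at the top of the vocabulary.
    print(tokenizer.convert_tokens_to_ids("<extra_id_0>") == tokenizer.vocab_size - 1)  # True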