# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional

from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_utils import BatchEncoding
from .tokenization_xlm_roberta import XLMRobertaTokenizer


logger = logging.getLogger(__name__)


# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = [
    "facebook/bart-base",
    "facebook/bart-large",
    "facebook/bart-large-mnli",
    "facebook/bart-large-cnn",
    "facebook/bart-large-xsum",
    "yjernite/bart_eli5",
]


class BartTokenizer(RobertaTokenizer):
    # merges and vocab same as Roberta
    max_model_input_sizes = {m: 1024 for m in _all_bart_models}
    pretrained_vocab_files_map = {
        "vocab_file": {m: vocab_url for m in _all_bart_models},
        "merges_file": {m: merges_url for m in _all_bart_models},
    }
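
# Usage sketch (illustrative, not part of the module): since BartTokenizer only
# overrides the pretrained-file maps, it tokenizes exactly like RobertaTokenizer.
# The same applies to BartTokenizerFast below.
#
#   >>> from transformers import BartTokenizer
#   >>> tok = BartTokenizer.from_pretrained("facebook/bart-large")
#   >>> tok("Hello world")["input_ids"]  # <s> ... </s>; exact ids depend on the vocab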


class BartTokenizerFast(RobertaTokenizerFast):
    # merges and vocab same as Roberta
    max_model_input_sizes = {m: 1024 for m in _all_bart_models}
    pretrained_vocab_files_map = {
        "vocab_file": {m: vocab_url for m in _all_bart_models},
        "merges_file": {m: merges_url for m in _all_bart_models},
    }


_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"

FAIRSEQ_LANGUAGE_CODES = [
    "ar_AR",
    "cs_CZ",
    "de_DE",
    "en_XX",
    "es_XX",
    "et_EE",
    "fi_FI",
    "fr_XX",
    "gu_IN",
    "hi_IN",
    "it_IT",
    "ja_XX",
    "kk_KZ",
    "ko_KR",
    "lt_LT",
    "lv_LV",
    "my_MM",
    "ne_NP",
    "nl_XX",
    "ro_RO",
    "ru_RU",
    "si_LK",
    "tr_TR",
    "vi_VN",
    "zh_CN",
]


class MBartTokenizer(XLMRobertaTokenizer):
    """
    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
    Other tokenizer methods like ``encode`` do not work properly.
    The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
    ``<language code> <tokens> <eos>`` for target language documents.

    Examples::

        >>> from transformers import MBartTokenizer
        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
        >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
        >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
        >>> batch: dict = tokenizer.prepare_translation_batch(
        ...     [example_english_phrase], src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=[expected_translation_romanian]
        ... )

    """

    vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
    max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
    pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}

    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.sp_model_size = len(self.sp_model)
        self.lang_code_to_id = {
            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
        }
        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
        self.cur_lang_code = self.lang_code_to_id["en_XX"]
        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset

        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
        self._additional_special_tokens = list(self.lang_code_to_id.keys())
        self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
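
    # Worked example of the id layout built above (hypothetical numbers, not the
    # real mbart-large vocab): with sp_model_size = 100 and fairseq_offset = 1,
    # "ar_AR" maps to 101, ..., "zh_CN" to 125, and "<mask>" to 126, i.e. the
    # language codes and the mask token sit directly after the sentencepiece vocab.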

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens. The special tokens depend on calling set_lang.
        An MBART sequence has the following format, where ``X`` represents the sequence:

        - ``input_ids`` (for encoder): ``X [eos, src_lang_code]``
        - ``decoder_input_ids`` (for decoder): ``[tgt_lang_code] X [eos]``

        BOS is never used.
        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
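
    # Worked example (hypothetical ids): in the source-language setting,
    # prefix_tokens = [] and suffix_tokens = [eos_id, en_XX_id], so
    # build_inputs_with_special_tokens([5, 6]) -> [5, 6, eos_id, en_XX_id];
    # after set_tgt_lang_special_tokens("ro_RO") it is [ro_RO_id, 5, 6, eos_id].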

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` methods.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
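
    # Worked example: with prefix_tokens = [] and suffix_tokens = [eos_id, lang_id]
    # (the source-language setting), get_special_tokens_mask([5, 6, 7]) returns
    # [0, 0, 0, 1, 1]: one 0 per sequence token, one 1 per added special token.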

    def prepare_translation_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """Prepare a batch that can be passed directly to an instance of MBartModel.

        Arguments:
            src_texts: list of src language texts
            src_lang: default en_XX (English), the language we are translating from
            tgt_texts: list of tgt language texts
            tgt_lang: default ro_RO (Romanian), the language we are translating to
            max_length: default=None, which defers to the config value of 1024 for facebook/mbart-large* checkpoints
            max_target_length: default=None, which falls back to max_length
            padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
            return_tensors: default "pt", the tensor type to return
            **kwargs: passed to self.__call__

        Returns:
            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
        """
        if max_length is None:
            max_length = self.max_len
        self.set_src_lang_special_tokens(src_lang)
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=True,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
        self.set_tgt_lang_special_tokens(tgt_lang)
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=True,
            **kwargs,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        self.set_src_lang_special_tokens(src_lang)  # restore the src_lang setting
        return model_inputs
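
    # Usage sketch (mirrors the class docstring; not executed here):
    #   >>> batch = tokenizer.prepare_translation_batch(
    #   ...     ["UN Chief Says There Is No Military Solution in Syria"],
    #   ...     src_lang="en_XX",
    #   ...     tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    #   ...     tgt_lang="ro_RO",
    #   ... )
    #   >>> sorted(batch.keys())
    #   ['attention_mask', 'decoder_attention_mask', 'decoder_input_ids', 'input_ids']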

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. Prefix=[tgt_lang_code] and suffix=[eos]."""
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = [self.cur_lang_code]
        self.suffix_tokens = [self.eos_token_id]
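

# Minimal end-to-end sketch (assumes the pretrained files above are reachable;
# ids in the trailing comments are placeholders, not real vocabulary values):
#
#   >>> from transformers import MBartTokenizer
#   >>> tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
#   >>> tok.set_src_lang_special_tokens("en_XX")
#   >>> (tok.prefix_tokens, tok.suffix_tokens)   # ([], [eos_id, en_XX_id])
#   >>> tok.set_tgt_lang_special_tokens("ro_RO")
#   >>> (tok.prefix_tokens, tok.suffix_tokens)   # ([ro_RO_id], [eos_id])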