import json
import re
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union

import sentencepiece

from .tokenization_utils import BatchEncoding, PreTrainedTokenizer


vocab_files_names = {
    "source_spm": "source.spm",
    "target_spm": "target.spm",
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json


class MarianTokenizer(PreTrainedTokenizer):
    """Sentencepiece tokenizer for Marian. Source and target languages have different SPM models.
    The logic is to use the relevant source_spm or target_spm to encode the text as pieces, then look up
    each piece in a vocab dictionary.

    Examples::

        >>> from transformers import MarianTokenizer
        >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
        >>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
        >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
        >>> batch_enc: BatchEncoding = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts)
        >>> # keys: [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
        >>> # model(**batch_enc) should work
    """

    vocab_files_names = vocab_files_names
    model_input_names = ["attention_mask"]  # actually attention_mask, decoder_attention_mask
    language_code_re = re.compile(">>.+<<")  # type: re.Pattern

    def __init__(
        self,
        vocab,
        source_spm,
        target_spm,
        source_lang=None,
        target_lang=None,
        unk_token="<unk>",
        eos_token="</s>",
        pad_token="<pad>",
        model_max_length=512,
        **kwargs
    ):
        super().__init__(
            # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
            model_max_length=model_max_length,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )
        assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
        self.encoder = load_json(vocab)
        if self.unk_token not in self.encoder:
            raise KeyError("<unk> token must be in vocab")
        assert self.pad_token in self.encoder
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.source_lang = source_lang
        self.target_lang = target_lang
        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
        self.spm_files = [source_spm, target_spm]

        # load SentencePiece model for pre-processing
        self.spm_source = load_spm(source_spm)
        self.spm_target = load_spm(target_spm)
        self.current_spm = self.spm_source

        # Multilingual target side: default to using first supported language code.

        self._setup_normalizer()

    def _setup_normalizer(self):
        try:
            from sacremoses import MosesPunctNormalizer

            self.punc_normalizer = MosesPunctNormalizer(self.source_lang).normalize
        except (ImportError, FileNotFoundError):
            warnings.warn("Recommended: pip install sacremoses.")
            self.punc_normalizer = lambda x: x

    def normalize(self, x: str) -> str:
        """Cover moses empty string edge case. They return empty list for '' input!"""
        return self.punc_normalizer(x) if x else ""
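    # Behavior sketch (the second output is illustrative and assumes sacremoses is installed):
    #   normalize("")              -> ""   (empty-string edge case handled explicitly)
    #   normalize('Hello „world“') -> 'Hello "world"'  (Moses punctuation normalization)
    # Without sacremoses, the identity fallback returns the text unchanged.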

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder[self.unk_token])

    def remove_language_code(self, text: str):
        """Remove language codes like >>fr<< before sentencepiece"""
        match = self.language_code_re.match(text)
        code: list = [match.group(0)] if match else []
        return code, self.language_code_re.sub("", text)

    def _tokenize(self, text: str) -> List[str]:
        code, text = self.remove_language_code(text)
        pieces = self.current_spm.EncodeAsPieces(text)
        return code + pieces
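    # Sketch of what _tokenize produces for a multilingual target model (the piece strings are
    # illustrative; the actual pieces depend on the trained source.spm):
    #   ">>fr<< How are you?" -> [">>fr<<", "▁How", "▁are", "▁you", "?"]
    #   "How are you?"        -> ["▁How", "▁are", "▁you", "?"]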

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) to a token (str) using the decoder."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses the target language sentencepiece model."""
        return self.spm_target.DecodePieces(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def prepare_translation_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        pad_to_max_length: bool = True,
        return_tensors: str = "pt",
        truncation_strategy="only_first",
        padding="longest",
    ) -> BatchEncoding:
        """Prepare model inputs for translation. For best performance, translate one sentence at a time.

        Arguments:
            src_texts: list of source language texts
            tgt_texts: list of target language texts (optional)
            max_length: (None) defer to config (1024 for mbart-large-en-ro)
            pad_to_max_length: (bool) whether to pad each example up to max_length
            return_tensors: (str) default "pt" returns pytorch tensors, pass None to return lists.

        Returns:
            BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask],
            all shaped (batch_size, seq_len). (BatchEncoding is a dict of string -> tensor or lists.)
            If tgt_texts is not specified, the only keys will be input_ids and attention_mask.
        """
        if "" in src_texts:
            raise ValueError(f"found empty string in src_texts: {src_texts}")
        self.current_spm = self.spm_source
        src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            truncation_strategy=truncation_strategy,
            padding=padding,
        )
        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)

        if tgt_texts is None:
            return model_inputs

        self.current_spm = self.spm_target
        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        self.current_spm = self.spm_source
        return model_inputs

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
        """Save the vocab file to json and copy the spm files from their original path."""
        save_dir = Path(save_directory)
        assert save_dir.is_dir(), f"{save_directory} should be a directory"
        save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])

        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
            dest_path = save_dir / orig
            if not dest_path.exists():
                copyfile(f, dest_path)

        return tuple(save_dir / f for f in self.vocab_files_names.values())

    def get_vocab(self) -> Dict:
        vocab = self.encoder.copy()
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self) -> Dict:
        state = self.__dict__.copy()
        state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
        return state

    def __setstate__(self, d: Dict) -> None:
        self.__dict__ = d
        self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files)
        self.current_spm = self.spm_source
        self._setup_normalizer()
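    # Pickling sketch (behavior follows from __getstate__/__setstate__ above): the loaded
    # SentencePiece processors and the punctuation normalizer are dropped before pickling and
    # rebuilt from self.spm_files / self.source_lang on unpickling, so something like
    # pickle.loads(pickle.dumps(tok)) yields a working tokenizer.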

    def num_special_tokens_to_add(self, **unused):
        """Just EOS"""
        return 1

    def _special_token_mask(self, seq):
        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Get a list where entries are 1 if a token is [eos] or [pad], else 0."""
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        else:
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
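    # Example (hypothetical ids, for illustration only): with token_ids_0 = [1425, 7, 912]
    # containing no special ids and already_has_special_tokens=False, the mask is [0, 0, 0, 1]:
    # one 0 per input token plus a trailing 1 for the eos token that
    # build_inputs_with_special_tokens appends.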


def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(path)
    return spm


def save_json(data, path: str) -> None:
    with open(path, "w") as f:
        json.dump(data, f, indent=2)


def load_json(path: str) -> Union[Dict, List]:
    with open(path, "r") as f:
        return json.load(f)
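

# Minimal usage sketch, mirroring the class docstring example. It assumes the transformers
# package and the 'Helsinki-NLP/opus-mt-en-de' checkpoint are available; run it via
# `python -m` or from a separate script, since this module uses a relative import.
if __name__ == "__main__":
    from transformers import MarianTokenizer  # top-level import, as in the docstring example

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
    tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]

    # With tgt_texts: keys are input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
    batch = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts)
    print(sorted(batch.keys()))
    print(batch["input_ids"].shape)  # (batch_size, seq_len)

    # Without tgt_texts: only input_ids and attention_mask are returned.
    src_only = tok.prepare_translation_batch(src_texts)
    print(sorted(src_only.keys()))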