CSS-LM

Форк
0
/
tokenization_marian.py 
239 строк · 9.3 Кб
1
import json
2
import re
3
import warnings
4
from pathlib import Path
5
from shutil import copyfile
6
from typing import Dict, List, Optional, Tuple, Union
7

8
import sentencepiece
9

10
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
11

12

13
# Logical name -> on-disk file name for every file that makes up a saved Marian tokenizer.
vocab_files_names = dict(
    source_spm="source.spm",
    target_spm="target.spm",
    vocab="vocab.json",
    tokenizer_config_file="tokenizer_config.json",
)
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json
20

21

22
class MarianTokenizer(PreTrainedTokenizer):
    """Sentencepiece tokenizer for marian. Source and target languages have different SPM models.

    Text is encoded into pieces with the relevant SentencePiece model (``source_spm`` for source
    text, ``target_spm`` for target text), then each piece is looked up in a shared vocab dictionary.

    Examples::

        >>> from transformers import MarianTokenizer
        >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
        >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
        >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
        >>> batch_enc: BatchEncoding = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts)
        >>> # keys  [input_ids, attention_mask, decoder_input_ids,  decoder_attention_mask].
        >>> # model(**batch) should work
    """

    vocab_files_names = vocab_files_names
    model_input_names = ["attention_mask"]  # actually attention_mask, decoder_attention_mask
    # Matches a language code such as ">>fr<<". Non-greedy so a later "<<" in the sentence is not
    # swallowed into the code (which would then map to <unk>).
    language_code_re = re.compile(">>.+?<<")  # type: re.Pattern

    def __init__(
        self,
        vocab,
        source_spm,
        target_spm,
        source_lang=None,
        target_lang=None,
        unk_token="<unk>",
        eos_token="</s>",
        pad_token="<pad>",
        model_max_length=512,
        **kwargs
    ):
        """Load the vocab json and both SentencePiece models.

        Args:
            vocab: path to the json file mapping pieces to ids.
            source_spm: path to the SentencePiece model used for source text.
            target_spm: path to the SentencePiece model used for target text and for decoding.
            source_lang: optional language passed to the Moses punctuation normalizer.
            target_lang: optional target language, stored as ``self.target_lang``.
            unk_token, eos_token, pad_token: special tokens; ``unk_token`` and ``pad_token`` must be in ``vocab``.
            model_max_length: forwarded to ``PreTrainedTokenizer``.

        Raises:
            AssertionError: if ``source_spm`` does not exist or ``pad_token`` is missing from the vocab.
            KeyError: if ``unk_token`` is missing from the vocab.
        """
        super().__init__(
            # bos_token=bos_token,  unused. Start decoding with config.decoder_start_token_id
            model_max_length=model_max_length,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )
        assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
        self.encoder = load_json(vocab)
        if self.unk_token not in self.encoder:
            raise KeyError("<unk> token must be in vocab")
        assert self.pad_token in self.encoder, f"{self.pad_token} token must be in vocab"
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.source_lang = source_lang
        self.target_lang = target_lang
        # Vocab entries shaped like ">>xx<<" (presumably target-language tokens of multilingual models).
        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
        # Original paths; reused by save_vocabulary and to reload the models in __setstate__.
        self.spm_files = [source_spm, target_spm]

        # load SentencePiece model for pre-processing
        self.spm_source = load_spm(source_spm)
        self.spm_target = load_spm(target_spm)
        # Source model is the default; prepare_translation_batch switches temporarily for targets.
        self.current_spm = self.spm_source

        self._setup_normalizer()

    def _setup_normalizer(self):
        """Use sacremoses' punctuation normalizer when available, otherwise the identity function."""
        try:
            from sacremoses import MosesPunctNormalizer

            self.punc_normalizer = MosesPunctNormalizer(self.source_lang).normalize
        except (ImportError, FileNotFoundError):
            warnings.warn("Recommended: pip install sacremoses.")
            self.punc_normalizer = lambda x: x

    def normalize(self, x: str) -> str:
        """Cover moses empty string edge case. They return empty list for '' input!"""
        return self.punc_normalizer(x) if x else ""

    def _convert_token_to_id(self, token):
        """Look ``token`` up in the vocab, falling back to the id of ``unk_token``."""
        return self.encoder.get(token, self.encoder[self.unk_token])

    def remove_language_code(self, text: str):
        """Split a leading language code like ``>>fr<<`` off ``text`` before sentencepiece.

        Returns:
            ``(code, rest)``: ``code`` is ``[">>fr<<"]`` when ``text`` starts with a language code
            and ``[]`` otherwise; ``rest`` is the text after that code. Only the leading code is
            removed, and ``>>..<<`` sequences elsewhere in the text are left untouched.
        """
        match = self.language_code_re.match(text)
        if match is None:
            return [], text
        return [match.group(0)], text[match.end() :]

    def _tokenize(self, text: str) -> List[str]:
        """Keep a leading language code as one token and split the rest with ``current_spm``."""
        code, text = self.remove_language_code(text)
        pieces = self.current_spm.EncodeAsPieces(text)
        return code + pieces

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the encoder."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses target language sentencepiece model"""
        return self.spm_target.DecodePieces(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def prepare_translation_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        pad_to_max_length: bool = True,
        return_tensors: str = "pt",
        truncation_strategy="only_first",
        padding="longest",
    ) -> BatchEncoding:
        """Prepare model inputs for translation. For best performance, translate one sentence at a time.

        Arguments:
            src_texts: list of src language texts; must not contain empty strings.
            tgt_texts: optional list of tgt language texts, tokenized with the target spm model.
            max_length: (None) maximum length forwarded to ``self.__call__``; None leaves the
                choice to the base tokenizer.
            pad_to_max_length: (bool) forwarded to ``self.__call__``.
            return_tensors: (str) default "pt" returns pytorch tensors, pass None to return lists.
            truncation_strategy: forwarded to ``self.__call__``.
            padding: forwarded to ``self.__call__``.

        Returns:
            BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids,  decoder_attention_mask]
            all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
            If no tgt_text is specified, the only keys will be input_ids and attention_mask.

        Raises:
            ValueError: if ``src_texts`` contains an empty string.
        """
        if "" in src_texts:
            raise ValueError(f"found empty string in src_texts: {src_texts}")
        self.current_spm = self.spm_source
        src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            truncation_strategy=truncation_strategy,
            padding=padding,
        )
        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)

        if tgt_texts is None:
            return model_inputs

        self.current_spm = self.spm_target
        try:
            decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
        finally:
            # Always switch back, even if target tokenization raised, so later calls use source spm.
            self.current_spm = self.spm_source
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        return model_inputs

    @property
    def vocab_size(self) -> int:
        """Number of entries in the json vocab (added tokens not included)."""
        return len(self.encoder)

    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
        """Save the vocab to json and copy the spm files from their original paths.

        The spm models are written as ``vocab_files_names["source_spm"]`` and
        ``vocab_files_names["target_spm"]`` inside ``save_directory``.

        Returns:
            Paths (inside ``save_directory``) of every file named in ``vocab_files_names``.

        Raises:
            AssertionError: if ``save_directory`` is not an existing directory.
        """
        save_dir = Path(save_directory)
        assert save_dir.is_dir(), f"{save_directory} should be a directory"
        save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])

        spm_save_names = [self.vocab_files_names["source_spm"], self.vocab_files_names["target_spm"]]
        for save_name, orig_path in zip(spm_save_names, self.spm_files):
            dest_path = save_dir / save_name
            # Saving back onto the file we loaded from: nothing to do (copyfile would raise SameFileError).
            if dest_path.exists() and dest_path.samefile(orig_path):
                continue
            copyfile(orig_path, dest_path)

        return tuple(save_dir / f for f in self.vocab_files_names.values())

    def get_vocab(self) -> Dict:
        """Return the json vocab merged with any added tokens."""
        vocab = self.encoder.copy()
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self) -> Dict:
        """Drop the SentencePiece processors and the normalizer callable before pickling.

        The fallback normalizer is a lambda, which pickle cannot serialize; all four are rebuilt in
        ``__setstate__`` from ``spm_files`` and ``source_lang``.
        """
        state = self.__dict__.copy()
        state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
        return state

    def __setstate__(self, d: Dict) -> None:
        """Restore pickled state and reload the spm models and normalizer dropped by ``__getstate__``."""
        self.__dict__ = d
        self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files)
        self.current_spm = self.spm_source
        self._setup_normalizer()

    def num_special_tokens_to_add(self, **unused):
        """Just EOS"""
        return 1

    def _special_token_mask(self, seq):
        """Mark ids of special tokens other than <unk> with 1 and everything else with 0."""
        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
        all_special_ids.discard(self.unk_token_id)  # <unk> is only sometimes special
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Get list where entries are [1] if a token is a special token (e.g. [eos], [pad]) else 0.

        When ``already_has_special_tokens`` is False, a trailing 1 is appended for the eos token that
        ``build_inputs_with_special_tokens`` would add.
        """
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        else:
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
224

225

226
def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
    """Create a SentencePiece processor and load the model file at ``path`` into it."""
    processor = sentencepiece.SentencePieceProcessor()
    processor.Load(path)
    return processor
230

231

232
def save_json(data, path: str) -> None:
    """Write ``data`` to ``path`` as indented JSON.

    The file is opened as UTF-8 explicitly (matching ``load_json``) instead of the platform's locale
    default. ``json.dump`` keeps its default ``ensure_ascii=True``, so non-ASCII keys are written as
    ``\\uXXXX`` escapes and the bytes on disk are plain ASCII.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
235

236

237
def load_json(path: str) -> Union[Dict, List]:
    """Read and parse the JSON file at ``path``.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than the platform's locale
    encoding. Vocab files with non-ASCII pieces (e.g. "\u2581" or accented words) therefore load
    correctly on systems whose default encoding is not UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
240

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.