colossalai

Форк
0
442 строки · 16.5 Кб
1
"""
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
"""
4
"""Tokenization classes for ChatGLM."""
5
import os
6
from typing import Dict, List, Optional, Union
7

8
import numpy as np
9
import sentencepiece as spm
10
from transformers.tokenization_utils import PreTrainedTokenizer
11
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
12
from transformers.utils import PaddingStrategy, logging
13

14
# Module-level logger obtained from ``transformers.utils.logging``.
logger = logging.get_logger(__name__)

# Maximum model input length per pretrained checkpoint name; exposed to ``transformers`` through
# ``ChatGLMTokenizer.max_model_input_sizes``.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "THUDM/chatglm-6b": 2048,
}
19

20

21
class TextTokenizer:
    """Thin wrapper around a ``sentencepiece.SentencePieceProcessor`` loaded from ``model_path``.

    Every method forwards directly to the underlying processor, so all ids and pieces here are raw
    SentencePiece ids/pieces with no offset applied.
    """

    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        # Cached SentencePiece vocabulary size; also what ``len()`` reports.
        self.num_tokens = self.sp.vocab_size()

    def encode(self, text):
        """Return the SentencePiece ids for ``text``."""
        return self.sp.EncodeAsIds(text)

    def decode(self, ids: List[int]):
        """Return the text for a list of SentencePiece ids."""
        return self.sp.DecodeIds(ids)

    def tokenize(self, text):
        """Return the SentencePiece pieces (token strings) for ``text``."""
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_string(self, tokens):
        """Join a list of SentencePiece pieces back into text."""
        return self.sp.DecodePieces(tokens)

    def convert_tokens_to_ids(self, tokens):
        """Map each piece in ``tokens`` to its SentencePiece id."""
        return [self.convert_token_to_id(token) for token in tokens]

    def convert_token_to_id(self, token):
        """Map a single piece to its SentencePiece id."""
        return self.sp.PieceToId(token)

    def convert_id_to_token(self, idx):
        """Map a single SentencePiece id to its piece."""
        return self.sp.IdToPiece(idx)

    def __len__(self):
        return self.num_tokens
50

51

52
class SPTokenizer:
    """Text tokenizer whose ids are SentencePiece ids shifted past a block of image placeholder ids.

    Ids ``[0, num_image_tokens)`` map to ``<image_{i}>`` placeholders, and SentencePiece id ``j`` maps
    to ``j + num_image_tokens``. Before encoding, newlines, tabs and runs of 2..``max_blank_length``
    spaces are rewritten to the marker pieces ``<n>``, ``<|tab|>`` and ``<|blank_{k}|>``.
    ``postprocess`` reverses that rewrite after decoding.
    """

    def __init__(
        self,
        vocab_file,
        num_image_tokens=20000,
        max_blank_length=80,
        byte_fallback=True,
    ):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.num_image_tokens = num_image_tokens
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.max_blank_length = max_blank_length
        # NOTE(review): stored for reference only; nothing in this class reads it.
        self.byte_fallback = byte_fallback
        self.text_tokenizer = TextTokenizer(vocab_file)

    def _get_text_tokenizer(self):
        return self.text_tokenizer

    @staticmethod
    def get_blank_token(length: int):
        """Return the marker piece that stands for ``length`` consecutive spaces (``length >= 2``)."""
        assert length >= 2
        return f"<|blank_{length}|>"

    @staticmethod
    def get_tab_token():
        """Return the marker piece that stands for a tab character."""
        return "<|tab|>"

    @property
    def num_text_tokens(self):
        """Size of the underlying SentencePiece vocabulary."""
        return self.text_tokenizer.num_tokens

    @property
    def num_tokens(self):
        """Total id space: image placeholder ids followed by the SentencePiece vocabulary."""
        return self.num_image_tokens + self.num_text_tokens

    @staticmethod
    def _encode_whitespaces(text: str, max_len: int = 80):
        """Replace tabs and runs of 2..``max_len`` spaces with their marker pieces."""
        text = text.replace("\t", SPTokenizer.get_tab_token())
        # Longest runs first, so a long run is not split into several shorter markers.
        for run_length in range(max_len, 1, -1):
            text = text.replace(" " * run_length, SPTokenizer.get_blank_token(run_length))
        return text

    def _preprocess(self, text: str, linebreak=True, whitespaces=True):
        """Apply the newline / whitespace marker rewriting selected by the flags."""
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = self._encode_whitespaces(text, max_len=self.max_blank_length)
        return text

    def _prepare_for_encoding(self, text: str, linebreak: bool, whitespaces: bool, add_dummy_prefix: bool) -> str:
        """Preprocess ``text`` and, when the dummy prefix is unwanted, prepend the ``<n>`` sentinel."""
        text = self._preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # The caller drops the first two pieces/ids produced for this sentinel. Presumably those
            # are SentencePiece's dummy-prefix piece and ``<n>``. TODO confirm against the vocab.
            text = "<n>" + text
        return text

    def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
        """
        @param text: Text to encode.
        @param linebreak: Whether to encode newline (\n) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
        @return: Token ids, already offset by ``num_image_tokens``.
        """
        text = self._prepare_for_encoding(text, linebreak, whitespaces, add_dummy_prefix)
        offset = self.num_image_tokens
        tokens = [piece_id + offset for piece_id in self._get_text_tokenizer().encode(text)]
        return tokens if add_dummy_prefix else tokens[2:]

    def postprocess(self, text):
        """Undo the newline / tab / blank-run marker rewriting done by ``_preprocess``."""
        text = text.replace("<n>", "\n")
        text = text.replace(SPTokenizer.get_tab_token(), "\t")
        for run_length in range(2, self.max_blank_length + 1):
            text = text.replace(self.get_blank_token(run_length), " " * run_length)
        return text

    def decode(self, text_ids: List[int]) -> str:
        """Decode offset ids to text; ids below ``num_image_tokens`` (image placeholders) are dropped."""
        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
        ids = [_id for _id in ids if _id >= 0]
        text = self._get_text_tokenizer().decode(ids)
        return self.postprocess(text)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join SentencePiece pieces into text and undo the marker rewriting."""
        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
        return self.postprocess(text)

    def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
        """
        @param text: Text to encode.
        @param linebreak: Whether to encode newline (\n) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
        @return: SentencePiece pieces (token strings).
        """
        text = self._prepare_for_encoding(text, linebreak, whitespaces, add_dummy_prefix)
        tokens = self._get_text_tokenizer().tokenize(text)
        return tokens if add_dummy_prefix else tokens[2:]

    def __getitem__(self, x: Union[int, str]):
        """Map an id to its token string, or a token string to its id.

        Integer ids (including numpy integer scalars) below ``num_image_tokens`` map to
        ``<image_{id}>``. Other ids map to SentencePiece pieces. Strings are mapped the opposite way.

        Raises:
            ValueError: if ``x`` is neither an integer nor a str.
        """
        if isinstance(x, np.integer):
            # Accept numpy scalars (e.g. elements of an id array) as plain ints.
            x = int(x)
        if isinstance(x, int):
            if x < self.num_image_tokens:
                return "<image_{}>".format(x)
            return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
        if isinstance(x, str):
            # ``isdecimal`` (not ``isdigit``) guarantees ``int()`` can parse the suffix.
            if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdecimal():
                return int(x[7:-1])
            return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
        raise ValueError("The key should be str or int.")
163

164

165
class ChatGLMTokenizer(PreTrainedTokenizer):
    """
    Construct a ChatGLM tokenizer backed by a SentencePiece model (``ice_text.model``).

    Ids ``[0, num_image_tokens)`` are image placeholder tokens. They are followed by the SentencePiece
    text vocabulary (see ``SPTokenizer``). Only left padding is supported.

    Args:
        vocab_file (`str`):
            Path to the SentencePiece model file.
    """

    vocab_files_names = {"vocab_file": "ice_text.model"}
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        remove_space=False,
        bos_token="<sop>",
        eos_token="<eop>",
        end_token="</s>",
        mask_token="[MASK]",
        gmask_token="[gMASK]",
        padding_side="left",
        pad_token="<pad>",
        unk_token="<unk>",
        num_image_tokens=20000,
        **kwargs,
    ) -> None:
        # Build everything the vocabulary lookups depend on *before* the base-class constructor runs.
        # Recent ``transformers`` releases may call ``get_vocab`` / token<->id conversion from inside
        # ``PreTrainedTokenizer.__init__``, and those methods need ``sp_tokenizer``.
        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.vocab_file = vocab_file
        self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)

        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            padding_side=padding_side,
            bos_token=bos_token,
            eos_token=eos_token,
            end_token=end_token,
            mask_token=mask_token,
            gmask_token=gmask_token,
            pad_token=pad_token,
            unk_token=unk_token,
            num_image_tokens=num_image_tokens,
            **kwargs,
        )

        # Special-token attributes are assigned after the base init, which sets up the special-token
        # machinery these setters go through.
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.end_token = end_token
        self.mask_token = mask_token
        self.gmask_token = gmask_token

    @property
    def gmask_token_id(self) -> Optional[int]:
        """`Optional[int]`: Id of ``gmask_token`` (``[gMASK]`` by default), or `None` if it is not set."""
        if self.gmask_token is None:
            return None
        return self.convert_tokens_to_ids(self.gmask_token)

    @property
    def end_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
        set.
        """
        if self.end_token is None:
            return None
        return self.convert_tokens_to_ids(self.end_token)

    @property
    def vocab_size(self):
        """Returns vocab size (image placeholder ids plus the SentencePiece vocabulary)."""
        return self.sp_tokenizer.num_tokens

    def get_vocab(self):
        """Returns vocab as a dict mapping token string -> id, including tokens added after loading."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def preprocess_text(self, inputs):
        """Optionally collapse whitespace (``remove_space``) and lowercase (``do_lower_case``) the input."""
        outputs = " ".join(inputs.strip().split()) if self.remove_space else inputs
        if self.do_lower_case:
            outputs = outputs.lower()
        return outputs

    def _tokenize(self, text, **kwargs):
        """Returns the list of SentencePiece pieces for ``text``."""
        text = self.preprocess_text(text)
        return self.sp_tokenizer.tokenize(text)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join pieces into text, restoring newlines, tabs and blank runs."""
        return self.sp_tokenizer.decode_tokens(tokens)

    def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """Decode ids to text after dropping every occurrence of the pad id."""
        if isinstance(token_ids, (int, np.integer)):
            token_ids = [token_ids]
        if len(token_ids) == 0:
            return ""
        pad_id = self.pad_token_id
        if pad_id is not None and pad_id in token_ids:  # remove pad
            # Compare element-wise with ``!=``. ``int.__ne__`` returns ``NotImplemented`` (truthy) for
            # numpy/torch integer ids, which would silently keep the pads.
            token_ids = [token_id for token_id in token_ids if token_id != pad_id]
        return super()._decode(token_ids, **kwargs)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_tokenizer[token]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_tokenizer[index]

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary (a byte-for-byte copy of ``self.vocab_file``) to a directory or file path.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary, or the target file path if it is not a
                directory.
            filename_prefix (`str`, *optional*):
                Accepted for API compatibility; this implementation does not use it.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
        else:
            vocab_file = save_directory

        # Read fully before writing, so saving onto the source path does not truncate it first.
        with open(self.vocab_file, "rb") as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs by appending the ``gmask_token`` and ``bos_token`` ids to the first sequence.

        With the default special tokens the format is:

        - single sequence: `X [gMASK] <sop>`
        - pair of sequences: `A [gMASK] <sop> B`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        gmask_id = self.sp_tokenizer[self.gmask_token]
        bos_id = self.sp_tokenizer[self.bos_token]
        token_ids_0 = token_ids_0 + [gmask_id, bos_id]
        if token_ids_1 is not None:
            # NOTE(review): no eos token is appended after the second sequence; confirm this is intended.
            token_ids_0 = token_ids_0 + token_ids_1
        return token_ids_0

    @staticmethod
    def _context_length(required_input, bos_token_id) -> int:
        """Index of the first ``bos_token_id`` in ``required_input``, or its length if absent."""
        if bos_token_id in required_input:
            return required_input.index(bos_token_id)
        return len(required_input)

    @staticmethod
    def _build_context_attention_mask(seq_length: int, context_length: int):
        """Return a ``(1, seq_length, seq_length)`` bool mask where ``True`` marks blocked positions.

        Every position may attend to the first ``context_length`` tokens. The remainder is causal
        (lower-triangular).
        """
        attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
        attention_mask[:, :, :context_length] = 1
        return np.bool_(attention_mask < 0.5)

    @staticmethod
    def _build_2d_position_ids(required_input, context_length: int, mask_id):
        """Return a ``(2, seq_length)`` int64 array of [position ids, block position ids]."""
        seq_length = len(required_input)
        position_ids = np.arange(seq_length, dtype=np.int64)
        if mask_id in required_input:
            # Positions from ``context_length`` onward all reuse the position of the mask token.
            position_ids[context_length:] = required_input.index(mask_id)
        block_position_ids = np.concatenate(
            [
                np.zeros(context_length, dtype=np.int64),
                np.arange(1, seq_length - context_length + 1, dtype=np.int64),
            ]
        )
        return np.stack([position_ids, block_position_ids], axis=0)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        padding_side: Optional[str] = None,
    ) -> dict:
        """
        Pad encoded inputs on the left up to a predefined length or the max length in the batch.

        When a target length is known (``max_length`` is set, or ``PaddingStrategy.LONGEST``), any missing
        ``attention_mask`` and ``position_ids`` are built first. Padding then follows.

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) for a single example.
            max_length: maximum length of the returned list and optionally padding length (see below).
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                Accepted for API compatibility; this implementation does not read it.
            padding_side:
                (optional) Overrides ``self.padding_side`` (newer ``transformers`` versions pass it). Only
                ``"left"`` is supported.
        """
        # Load from model defaults
        bos_token_id = self.sp_tokenizer[self.bos_token]
        mask_token_id = self.sp_tokenizer[self.mask_token]
        gmask_token_id = self.sp_tokenizer[self.gmask_token]
        side = self.padding_side if padding_side is None else padding_side
        assert side == "left", "ChatGLMTokenizer only supports left padding"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = seq_length

        if max_length is not None and pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
            max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_length != max_length

        # Initialize attention mask / position ids if not present.
        if max_length is not None:
            missing_mask = "attention_mask" not in encoded_inputs
            missing_positions = "position_ids" not in encoded_inputs
            if missing_mask or missing_positions:
                context_length = self._context_length(required_input, bos_token_id)
                if missing_mask:
                    encoded_inputs["attention_mask"] = self._build_context_attention_mask(seq_length, context_length)
                if missing_positions:
                    # Prefer [MASK] when it occurs in the input; otherwise fall back to [gMASK].
                    mask_id = mask_token_id if mask_token_id in required_input else gmask_token_id
                    encoded_inputs["position_ids"] = self._build_2d_position_ids(
                        required_input, context_length, mask_id
                    )

        if needs_to_be_padded:
            difference = max_length - seq_length

            if "attention_mask" in encoded_inputs:
                # Padded rows/columns are filled with True, i.e. blocked.
                encoded_inputs["attention_mask"] = np.pad(
                    encoded_inputs["attention_mask"],
                    pad_width=[(0, 0), (difference, 0), (difference, 0)],
                    mode="constant",
                    constant_values=True,
                )
            if "token_type_ids" in encoded_inputs:
                encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                    "token_type_ids"
                ]
            if "special_tokens_mask" in encoded_inputs:
                encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = np.pad(
                    encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
                )
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
443

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.