# paddlenlp
# (Code-viewer page residue: "Fork 0", "635 lines · 24.2 KB".)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import json
16
import math
17
import os
18
import re
19
import time
20

21
import numpy as np
22
import paddle
23
from ppfleetx.data.tokenizers import GPTTokenizer
24
from ppfleetx.distributed.apis import env
25
from ppfleetx.utils.log import logger
26

27
# TODO(haohongxiang): to solve the problem of cross-reference
28
import paddlenlp  # noqa: F401
29
from paddlenlp.transformers.gpt.tokenizer import GPTChineseTokenizer
30

31
# Position of each dataset mode's document range in the
# [train, valid, test] boundaries from get_train_valid_test_split_.
mode_to_index = dict(Train=0, Eval=1, Test=2)
32

33
# model_type -> (tokenizer class, name passed to ``from_pretrained``).
_GPT_EN_TOKENIZER = (GPTTokenizer, "gpt2")
_GPT_CN_TOKENIZER = (GPTChineseTokenizer, "gpt-cpm-large-cn")
MODEL_CLASSES = {"GPT": _GPT_EN_TOKENIZER, "GPT-cn": _GPT_CN_TOKENIZER}
37

38

39
class GPTDataset(paddle.io.Dataset):
    """GPT pretraining dataset that serves fixed-length token windows.

    The data is read from the first (sorted) prefix returned by
    ``get_train_data_file(input_dir)``. That is either the
    ``<prefix>_ids.npy`` / ``<prefix>_idx.npz`` pair or the legacy
    ``<prefix>_ids.npz`` file. The documents are divided into Train/Eval/Test
    ranges according to ``split``. The doc/sample/shuffle index maps are then
    built (local rank 0) or awaited (other ranks) through
    ``construct_samples_and_shuffle_data``.

    Args:
        input_dir: directory holding the pre-tokenized dataset files.
        split: relative sizes of the train/valid/test document ranges
            (see ``get_train_valid_test_split_``).
        max_seq_len: number of input tokens per sample.
        num_samples: number of samples the index maps must provide.
        mode: one of ``"Train"``, ``"Eval"`` or ``"Test"``.
        model_type: key into ``MODEL_CLASSES``; the tokenizer is only used for
            its ``eos_token_id``.
        seed: seed for document/sample shuffling.

    Raises:
        ValueError: if ``mode`` is not a known mode or a data file is missing.
    """

    def __init__(self, input_dir, split, max_seq_len, num_samples, mode, model_type="GPT", seed=1234):
        # Fail fast on an invalid mode, before compiling helpers or loading data.
        if mode not in mode_to_index:
            raise ValueError("Invalid mode {!r}: expected one of {}.".format(mode, list(mode_to_index)))
        index = mode_to_index[mode]

        # Only the first dataset prefix (in sorted order) is used.
        files = get_train_data_file(input_dir)
        files.sort()
        input_prefix = files[0]
        input_dir = [input_prefix]

        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

        # Local rank 0 compiles the C++ index helpers. The other ranks of a
        # multi-device run poll until the compiled module can be imported.
        if local_rank == 0:
            self._compile_index_helpers()
        device_world_size = paddle.distributed.get_world_size()
        if device_world_size > 1 and local_rank != 0:
            self._wait_for_index_helpers()

        try:
            data_world_size = env.get_data_world_size()

            logger.info(
                "The distributed run, total device num:{}, distinct dataflow num:{}.".format(
                    device_world_size, data_world_size
                )
            )
        except AttributeError:
            # Best-effort logging: an AttributeError here is deliberately ignored.
            pass

        sample_ids, sample_lens = self._load_token_data(input_prefix)

        splits = get_train_valid_test_split_(split, len(sample_lens))
        assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
            len(sample_lens),
            splits[-1],
        )

        tokenizer_class, pretrained_name = MODEL_CLASSES[model_type]
        tokenizer = tokenizer_class.from_pretrained(pretrained_name)

        self.input_dir = input_dir  # one-element list holding the dataset prefix
        self.max_seq_len = max_seq_len
        self.mode = mode
        self.name = "gpt_" + mode
        self.eos_id = tokenizer.eos_token_id
        self.sample_ids = sample_ids
        self.sample_lens = sample_lens
        # Only local rank 0 builds the cached index-map files.
        self.build_data_file = local_rank == 0

        # Documents of this split: [splits[index], splits[index + 1]).
        document_ids = np.arange(splits[index], splits[index + 1])

        self.doc_idx, self.sample_idx, self.shuffle_idx = construct_samples_and_shuffle_data(
            self.name,
            input_prefix,
            document_ids,
            self.sample_lens,
            num_samples,
            max_seq_len,
            seed,
            self.build_data_file,
        )

        # start_pos[d] is the offset of document d in sample_ids. The sum is
        # accumulated in int64 because sample_lens may be int32 and the total
        # token count can exceed 2**31 - 1.
        self.start_pos = [0] + np.cumsum(self.sample_lens, dtype=np.int64).tolist()

    @staticmethod
    def _compile_index_helpers():
        """Compile the C++ index helper extension unless it already imports."""
        try:
            import ppfleetx.data.data_tools.cpp.fast_index_map_helpers  # noqa: F401
        except Exception:
            start_time = time.time()
            print("> compiling dataset index builder ...")
            from ppfleetx.data.data_tools.cpp.compile import compile_helper

            compile_helper()
            print(
                ">>> done with dataset index builder. Compilation time: {:.3f} "
                "seconds".format(time.time() - start_time),
                flush=True,
            )

    @staticmethod
    def _wait_for_index_helpers():
        """Poll once per second until the helper extension can be imported."""
        while True:
            try:
                import ppfleetx.data.data_tools.cpp.fast_index_map_helpers  # noqa: F401, F811

                return
            except Exception:
                print("> wait for helpers to be compiled!")
                time.sleep(1)

    @staticmethod
    def _load_token_data(input_prefix):
        """Return ``(sample_ids, sample_lens)`` for the dataset at ``input_prefix``.

        ``sample_ids`` holds the token ids of all documents back to back.
        ``sample_lens`` holds each document's length, so
        ``sum(sample_lens) == len(sample_ids)``.

        Raises:
            ValueError: if the current-format files are missing.
        """
        if os.path.isfile(input_prefix + "_ids.npz"):
            logger.warning("You are using compatible dataset, please make new dataset as the readme!")
            # NpzFile members are read fully on access, so the archive can be
            # closed as soon as both arrays have been extracted.
            with np.load(input_prefix + "_ids.npz", mmap_mode="r+", allow_pickle=True) as process_data:
                sample_ids = process_data["ids"]
                sample_lens = process_data["lens"].astype("int32")
            return sample_ids, sample_lens

        for suffix in ["_ids.npy", "_idx.npz"]:
            if not os.path.isfile(input_prefix + suffix):
                raise ValueError("File Not found, %s" % (input_prefix + suffix))

        # All document ids as one memory-mapped 1-D array.
        sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
        with np.load(input_prefix + "_idx.npz") as process_data:
            sample_lens = process_data["lens"]
        return sample_ids, sample_lens

    def _construct_sample(self, tokens):
        """Turn a window of token ids into one training/eval sample.

        The window is normally ``max_seq_len + 1`` ids long. ``tokens`` (the
        inputs) and ``labels`` (the next-token targets) are int64 arrays.
        ``loss_mask`` is float32 and is 0 wherever the input token is EOS.

        Returns:
            ``[tokens, position_ids]`` in ``"Test"`` mode, otherwise
            ``[tokens, position_ids, labels, loss_mask]``.
        """
        sequence = np.array(tokens).astype("int64")
        tokens = sequence[:-1]
        labels = sequence[1:]
        seq_length = tokens.shape[0]
        # The pad and eos tokens do not contribute the loss. ``tokens`` must be
        # an ndarray here: comparing a Python list with an int yields a single
        # False, which would silently mask nothing.
        loss_mask = np.ones(seq_length, dtype="float32")
        loss_mask[tokens == self.eos_id] = 0.0
        position_ids = np.arange(0, seq_length, dtype="int64")

        if self.mode == "Test":
            return [tokens, position_ids]
        return [tokens, position_ids, labels, loss_mask]

    def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f, offset_l):
        """Collect the token ids of one sample, possibly spanning several documents.

        Args:
            doc_index_f: position in ``self.doc_idx`` of the first document.
            doc_index_l: position in ``self.doc_idx`` of the last document.
            offset_f: token offset of the sample start in the first document.
            offset_l: token offset of the sample end (inclusive) in the last document.

        Returns:
            A list of token ids.
        """
        if doc_index_f == doc_index_l:
            # The whole sample lies inside a single document.
            current_start_pos = self.start_pos[self.doc_idx[doc_index_f]]
            return self.sample_ids[current_start_pos + offset_f : current_start_pos + offset_l + 1].tolist()

        # Tail of the first document.
        current_start_pos = self.start_pos[self.doc_idx[doc_index_f]]
        next_start_pos = self.start_pos[self.doc_idx[doc_index_f] + 1]
        tokens = self.sample_ids[current_start_pos + offset_f : next_start_pos].tolist()
        # Every document strictly between the first and the last.
        for i in range(doc_index_f + 1, doc_index_l):
            current_start_pos = self.start_pos[self.doc_idx[i]]
            next_start_pos = self.start_pos[self.doc_idx[i] + 1]
            tokens.extend(self.sample_ids[current_start_pos:next_start_pos].tolist())
        # Head of the last document, up to and including offset_l.
        last_start_pos = self.start_pos[self.doc_idx[doc_index_l]]
        tokens.extend(self.sample_ids[last_start_pos : last_start_pos + offset_l + 1].tolist())
        return tokens

    def __getitem__(self, index):
        """Return the sample at shuffled position ``index``."""
        idx = self.shuffle_idx[index]
        # Start and end (document position, offset) pairs of the sample.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        tokens = self._get_single_sample_from_idx(doc_index_f, doc_index_l, offset_f, offset_l)
        return self._construct_sample(tokens)

    def __len__(self):
        """Number of samples described by ``sample_idx`` (one row per boundary)."""
        return self.sample_idx.shape[0] - 1
210

211

212
def get_train_data_file(input_dir):
    """Return the dataset path prefixes found in ``input_dir``.

    Files named ``<prefix>_idx.npz`` (the current format, paired with
    ``<prefix>_ids.npy``) are preferred. If there are none, a warning is logged
    and legacy ``<prefix>_ids.npz`` files are used instead. Prefixes are
    returned as ``os.path.join(input_dir, prefix)`` in ``os.listdir`` order.

    Raises:
        RuntimeError: if neither kind of file exists in ``input_dir``.
    """

    def _prefixes(suffix):
        # Slice off the trailing suffix only. ``str.replace`` would also
        # rewrite the suffix wherever it appears in the directory path.
        return [
            os.path.join(input_dir, name[: -len(suffix)])
            for name in os.listdir(input_dir)
            if str(name).endswith(suffix) and os.path.isfile(os.path.join(input_dir, name))
        ]

    files = _prefixes("_idx.npz")
    if files:
        return files

    logger.warning(
        "Not found dataset with name of xxx_ids.npy and xxx_idx.npz! Try to found old compatible xxx_ids.npz file."
    )
    files = _prefixes("_ids.npz")
    if not files:
        raise RuntimeError("Not found dataset with name of xxx_ids.npz in given input_dir '{}'! ".format(input_dir))
    return files
238

239

240
def get_train_valid_test_split_(splits, size):
    """Return the document boundaries of the train/valid/test splits.

    Args:
        splits: relative split weights. This is either a sequence of numbers
            (or numeric strings) such as ``[949, 50, 1]``, or a string
            separated by ``,`` or ``/`` such as ``"949,50,1"``. Missing entries
            count as 0 and entries beyond the third are ignored.
        size: total number of documents.

    Returns:
        Four boundaries ``[0, train_end, valid_end, size]``; split ``i`` covers
        documents ``[result[i], result[i + 1])``.
    """
    if isinstance(splits, str):
        # Without this, a string would be iterated character by character.
        splits = re.split(r"[,/]", splits)

    weights = [float(s) for s in splits]
    while len(weights) < 3:
        weights.append(0.0)
    weights = weights[:3]
    weights_sum = sum(weights)
    assert weights_sum > 0.0
    weights = [weight / weights_sum for weight in weights]

    splits_index = [0]
    for weight in weights:
        splits_index.append(splits_index[-1] + int(round(weight * float(size))))

    # Shift the end boundaries so the last one equals size exactly. Clamp at 0
    # so a rounding surplus can never produce a negative document index, which
    # np.arange/indexing would wrap around to the end of the dataset.
    diff = splits_index[-1] - size
    splits_index = [0] + [max(boundary - diff, 0) for boundary in splits_index[1:]]

    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
261

262

263
def _wait_for_index_map_files(filenames, last_filename, poll_interval=3):
    """Block until all ``filenames`` exist and ``last_filename`` can be mmap-loaded.

    ``last_filename`` should be the file the builder writes last. Once it
    loads, the earlier files are complete as well.
    """
    while True:
        if not all(os.path.isfile(path) for path in filenames):
            time.sleep(poll_interval)
            continue
        try:
            np.load(last_filename, allow_pickle=True, mmap_mode="r")
            return
        except Exception:
            print("%s file is still writing or damaged, please wait a moment." % last_filename)
            time.sleep(poll_interval)


def construct_samples_and_shuffle_data(
    name, data_prefix, documents, sizes, num_samples, seq_length, seed, build_data_file
):
    """Build (or wait for) and load the cached doc/sample/shuffle index maps.

    Args:
        name: tag embedded in the cache file names (e.g. ``"gpt_Train"``).
        data_prefix: dataset path prefix; the cache files are written next to it.
        documents: indices of the documents that belong to this split.
        sizes: per-document token counts of the whole dataset (int32).
        num_samples: number of samples the maps must provide.
        seq_length: tokens per sample. Each sample reads ``seq_length + 1``
            tokens so inputs and shifted labels can both be cut from it.
        seed: seed of the RNG that shuffles documents and samples.
        build_data_file: if True, build any missing cache file here.
            Otherwise poll until another process has written them.

    Returns:
        ``(doc_idx, sample_idx, shuffle_idx)``, memory-mapped read-only:

        * ``doc_idx``: the split's documents repeated once per epoch in
          shuffled order (``num_epochs * len(documents)`` entries).
        * ``sample_idx``: row ``i`` holds the ``(position in doc_idx, token
          offset)`` where sample ``i`` starts.
        * ``shuffle_idx``: a permutation of the sample indices.

    NOTE(review): the cache file name encodes ``name``, ``num_samples`` and
    ``seq_length`` but neither ``seed`` nor the document range, so changing
    either silently reuses stale maps. Delete the cached files after
    changing them.
    """
    # Tokens per epoch and the number of epochs needed for num_samples samples.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    np_rng = np.random.RandomState(seed=seed)

    # Filenames of the index mappings.
    _filename = "{}_{}_indexmap_{}ns_{}sl".format(data_prefix, name, num_samples, seq_length)
    doc_idx_filename = _filename + "_doc_idx.npy"
    sample_idx_filename = _filename + "_sample_idx.npy"
    shuffle_idx_filename = _filename + "_shuffle_idx.npy"
    index_files = (doc_idx_filename, sample_idx_filename, shuffle_idx_filename)

    if build_data_file:
        if not all(os.path.isfile(path) for path in index_files):
            if num_epochs == 1:
                separate_last_epoch = False
            else:
                # Samples fully provided by the first num_epochs - 1 epochs.
                num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length
                last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one
                assert last_epoch_num_samples >= 0, "last epoch number of samples should be non-negative."
                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
                assert last_epoch_num_samples < (
                    num_samples_per_epoch + 1
                ), "last epoch number of samples exceeded max value."
                # If less than 80% of an epoch is needed from the last one,
                # shuffle that partial epoch separately in both doc_idx and
                # shuffle_idx rather than mixing it into the full epochs.
                separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)

            # doc-idx: len(doc_idx) == num_epochs * len(documents).
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print(
                " > elasped time to build and save doc-idx mapping "
                "(seconds): {:4f}".format(time.time() - start_time)
            )

            # sample-idx: start position of every seq_length window.
            start_time = time.time()
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32

            from ppfleetx.data.data_tools.cpp import fast_index_map_helpers

            # Compiled counterpart of the pure-Python ``_build_sample_idx``
            # (same arguments).
            sample_idx = fast_index_map_helpers.build_sample_idx(
                sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
            )
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print(
                " > elasped time to build and save sample-idx mapping "
                "(seconds): {:4f}".format(time.time() - start_time)
            )

            # shuffle-idx: written last; waiting processes poll for this file.
            start_time = time.time()
            if separate_last_epoch:
                num_samples_ = num_samples_from_epochs_minus_one
            else:
                num_samples_ = sample_idx.shape[0] - 1
            shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print(
                " > elasped time to build and save shuffle-idx mapping"
                " (seconds): {:4f}".format(time.time() - start_time)
            )
    else:
        _wait_for_index_map_files(index_files, shuffle_idx_filename)

    try:
        if paddle.distributed.get_world_size() > 1 and paddle.in_dynamic_mode():
            paddle.distributed.barrier()
    except AssertionError:
        # Deliberately best-effort: an AssertionError from the distributed
        # calls is ignored, as before.
        pass

    # Load mappings.
    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r")
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
    return doc_idx, sample_idx, shuffle_idx
385

386

387
def _num_tokens(documents, lens):
388
    """Total number of tokens in the dataset."""
389
    return np.sum(lens[documents])
390

391

392
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
393
    """Based on number of samples and sequence lenght, calculate how many
394
    epochs will be needed."""
395
    num_epochs = 0
396
    total_tokens = 0
397
    while True:
398
        num_epochs += 1
399
        total_tokens += tokens_per_epoch
400
        if ((total_tokens - 1) // seq_length) >= num_samples:
401
            return num_epochs
402

403

404
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
405
    """
406
    Build an array with length = number-of-epochs * number-of-documents.
407
    Each index is mapped to a corresponding document.
408
    """
409
    if not separate_last_epoch or num_epochs == 1:
410
        doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
411
        doc_idx[:] = documents
412
        # The documents repeat num_epochs times.
413
        doc_idx = doc_idx.reshape(-1)
414
        doc_idx = doc_idx.astype(np.int32)
415
        np_rng.shuffle(doc_idx)
416
        return doc_idx
417

418
    doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
419
    doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
420
    return np.concatenate((doc_idx_first, doc_idx_last))
421

422

423
def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
424
    """
425
    num_samples + 1, pos of bs data
426
    the distance between two points for sample idx is bs tokens.
427
    """
428
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
429
    sample_idx = np.zeros([int(num_samples) + 1, 2], dtype=np.int32)
430

431
    sample_index = 0
432
    doc_idx_index = 0
433
    doc_offset = 0
434
    sample_idx[sample_index][0] = doc_idx_index
435
    sample_idx[sample_index][1] = doc_offset
436
    sample_index += 1
437
    while sample_index <= num_samples:
438
        remaining_seq_length = seq_length + 1
439
        while remaining_seq_length != 0:
440
            doc_id = doc_idx[doc_idx_index]
441
            doc_length = sizes[doc_id] - doc_offset
442
            remaining_seq_length -= doc_length
443
            if remaining_seq_length <= 0:
444
                doc_offset += remaining_seq_length + doc_length - 1
445
                remaining_seq_length = 0
446
            else:
447
                doc_idx_index += 1
448
                doc_offset = 0
449
        sample_idx[sample_index][0] = doc_idx_index
450
        sample_idx[sample_index][1] = doc_offset
451
        sample_index += 1
452

453
    return sample_idx
454

455

456
def _build_shuffle_idx(num_samples, total_size, np_rng):
457
    dtype_ = np.uint32
458
    if total_size >= (np.iinfo(np.uint32).max - 1):
459
        dtype_ = np.int64
460

461
    shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
462
    np_rng.shuffle(shuffle_idx_first)
463
    if num_samples == total_size:
464
        return shuffle_idx_first
465

466
    shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
467
    np_rng.shuffle(shuffle_idx_last)
468

469
    return np.concatenate((shuffle_idx_first, shuffle_idx_last))
470

471

472
class LM_Eval_Dataset(paddle.io.Dataset):
    """Sliding-window language-modeling evaluation dataset (WikiText style).

    The file at ``input_dir`` is read as UTF-8, detokenized with
    ``_wikitext_detokenizer`` and encoded once. Item ``idx`` is the window of
    ``max_seq_len + 1`` tokens that starts at ``idx * overlapping_eval``. A
    window that runs past the end of the text is padded with the EOS id, and
    the padded targets are masked out of the loss.
    """

    def __init__(self, input_dir, max_seq_len, overlapping_eval=None, model_type="GPT", **kwargs):
        tokenizer_class, pretrained_name = MODEL_CLASSES[model_type]
        tokenizer = tokenizer_class.from_pretrained(pretrained_name)

        with open(input_dir, "rb") as reader:
            entire_data = reader.read().decode("utf-8")

        # Space-separated word count of the raw text, reported alongside the
        # tokenized count.
        self.num_original_tokens = len(entire_data.strip().split(" "))
        entire_data = self._wikitext_detokenizer(entire_data)
        self.tokens = tokenizer.encode(entire_data)
        self.num_tokenized_tokens = len(self.tokens)
        print("Original Tokens: %d, Detokenized tokens: %d" % (self.num_original_tokens, self.num_tokenized_tokens))

        self.seq_len = max_seq_len
        self.pad_idx = tokenizer.eos_token_id
        # Stride between consecutive windows; defaults to non-overlapping windows.
        self.overlapping_eval = self.seq_len if overlapping_eval is None else overlapping_eval
        self.overlapping_eval = max(1, self.overlapping_eval)

        self.total_targets = len(self.tokens) - 1
        # One window for the start of the text, then one per further
        # overlapping_eval targets.
        targets = max(self.total_targets - self.overlapping_eval, 0)
        self.total_sequences = max(math.ceil(targets / self.overlapping_eval) + 1, 1)

    def __len__(self):
        return self.total_sequences

    def _construct_sample(self, tokens):
        """Split a window of ``seq_len + 1`` ids into inputs and next-token targets.

        Returns ``[tokens, loss_mask, attention_mask, position_ids, labels]``.
        ``tokens`` and ``labels`` are Python lists of ints. ``loss_mask`` is
        all ones; ``__getitem__`` masks the padded positions.
        ``attention_mask`` is a float32 lower-triangular ``(1, n, n)`` array and
        ``position_ids`` is ``arange(n)`` as int64.
        """
        tokens = np.array(tokens).astype("int64").tolist()
        labels = tokens[1:]
        tokens = tokens[:-1]
        seq_length = len(tokens)
        # Causal attention: position i attends to positions <= i.
        attention_mask = np.tri(seq_length, seq_length).reshape((1, seq_length, seq_length))
        attention_mask = attention_mask.astype("float32")
        loss_mask = np.ones(seq_length, dtype="float32")
        position_ids = np.arange(0, seq_length, dtype="int64")
        return [tokens, loss_mask, attention_mask, position_ids, labels]

    def __getitem__(self, idx):
        start_idx = idx * self.overlapping_eval
        end_idx = start_idx + self.seq_len
        # Copy the window so padding never touches self.tokens.
        tokens = list(self.tokens[start_idx : end_idx + 1])
        num_tokens = len(tokens)
        if num_tokens < self.seq_len + 1:
            num_pad = self.seq_len + 1 - num_tokens
            tokens += [self.pad_idx] * num_pad
        [tokens, loss_mask, attention_mask, position_ids, labels] = self._construct_sample(tokens)
        # labels[i] is window token i + 1, so from position num_tokens - 1 on
        # the targets are padding and must not be scored.
        loss_mask[max(num_tokens - 1, 0) :] = 0.0
        if self.overlapping_eval != self.seq_len and idx != 0:
            # Overlapping windows only score their last overlapping_eval targets.
            loss_mask[: -self.overlapping_eval] *= 0

        return [
            tokens,
            loss_mask,
            attention_mask,
            position_ids,
            labels,
            np.array([self.num_original_tokens, self.num_tokenized_tokens]),
        ]

    def _wikitext_detokenizer(self, string):
        """Undo WikiText's tokenization artifacts (spacing, ``@-@`` separators, ...)."""
        # contractions
        string = string.replace("s '", "s'")
        # NOTE(review): this pattern matches the literal text "/' <digit>/" and
        # inserts a literal "[0-9]", which looks like a mistranslated JS regex.
        # It is kept unchanged so evaluation numbers stay comparable.
        string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
        # number separators
        string = string.replace(" @-@ ", "-")
        string = string.replace(" @,@ ", ",")
        string = string.replace(" @.@ ", ".")
        # punctuation
        string = string.replace(" : ", ": ")
        string = string.replace(" ; ", "; ")
        string = string.replace(" . ", ". ")
        string = string.replace(" ! ", "! ")
        string = string.replace(" ? ", "? ")
        string = string.replace(" , ", ", ")
        # double brackets
        string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
        string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
        string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
        string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
        string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
        # miscellaneous
        string = string.replace("= = = =", "====")
        string = string.replace("= = =", "===")
        string = string.replace("= =", "==")
        string = string.replace(" " + chr(176) + " ", chr(176))
        string = string.replace(" \n", "\n")
        string = string.replace("\n ", "\n")
        string = string.replace(" N ", " 1 ")
        string = string.replace(" 's", "'s")
        return string
572

573

574
class Lambada_Eval_Dataset(paddle.io.Dataset):
    """LAMBADA-style last-word prediction dataset.

    Each non-blank line of ``input_dir`` is a JSON object with a ``"text"``
    field. The context (everything before the last word) and the last word
    are tokenized separately. Only the positions whose target is a token of
    the last word are scored.
    """

    def __init__(self, input_dir, max_seq_len, model_type="GPT", **kwargs):
        tokenizer_class, pretrained_name = MODEL_CLASSES[model_type]
        tokenizer = tokenizer_class.from_pretrained(pretrained_name)

        tokenized_data = []
        tokenized_label = []
        # Explicit UTF-8: the default text encoding is platform dependent.
        with open(input_dir, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # json.loads would reject a blank line.
                    continue
                text = json.loads(line)["text"]
                tokens, labels = self._get_tokens(tokenizer, text)
                tokenized_data.append(tokens)
                tokenized_label.append(labels)

        self.pad_idx = tokenizer.eos_token_id
        self.seq_len = max_seq_len
        self.tokens = tokenized_data
        self.labels = tokenized_label

    def __len__(self):
        return len(self.tokens)

    def _construct_sample(self, tokens):
        """Split ``seq_len + 1`` ids into inputs and next-token targets.

        Returns ``[tokens, attention_mask, position_ids, labels]``. ``tokens``
        and ``labels`` are Python lists of ints, ``attention_mask`` is a
        float32 lower-triangular ``(1, n, n)`` array, and ``position_ids`` is
        ``arange(n)`` as int64.
        """
        tokens = np.array(tokens).astype("int64").tolist()
        labels = tokens[1:]
        tokens = tokens[:-1]

        seq_length = len(tokens)
        # Causal attention: position i attends to positions <= i.
        attention_mask = np.tri(seq_length, seq_length).reshape((1, seq_length, seq_length))
        attention_mask = attention_mask.astype("float32")
        position_ids = np.arange(0, seq_length, dtype="int64")
        return [tokens, attention_mask, position_ids, labels]

    def __getitem__(self, idx):
        labels = self.labels[idx]
        # Truncate the context so context + target fits in seq_len + 1 tokens.
        # For one-token targets this is the same seq_len cut as before. Longer
        # targets previously yielded over-long samples that disagreed with
        # loss_mask.
        max_context = max(self.seq_len + 1 - len(labels), 0)
        tokens = self.tokens[idx][:max_context] + labels
        num_tokens = len(tokens)
        if num_tokens < self.seq_len + 1:
            num_pad = self.seq_len + 1 - num_tokens
            tokens += [self.pad_idx] * num_pad
        # Score only the positions whose next token belongs to the last word.
        loss_mask = np.zeros(self.seq_len, dtype="float32")
        loss_mask[num_tokens - len(labels) - 1 : num_tokens - 1] = 1.0
        [tokens, attention_mask, position_ids, labels] = self._construct_sample(tokens)
        return [tokens, loss_mask, attention_mask, position_ids, labels, np.array([len(self)])]

    def _get_tokens(self, tokenizer, text, strict=True):
        """Tokenize ``text`` into ``(context_tokens, last_word_tokens)``.

        With ``strict`` the last whitespace-separated word is encoded on its
        own (with a leading space). Otherwise the whole text is encoded and
        its last token is taken as the target.
        """
        if not strict:
            tokens = tokenizer.encode(text)
            return tokens[:-1], [tokens[-1]]
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = tokenizer.encode(text[:start_idx].strip())
        last_token = tokenizer.encode(" " + last_token)
        return beginning_tokens, last_token

# --- Code-viewer page residue (cookie notice, translated from Russian) ---
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal
# data in order to improve our website and the GitVerse service and make them
# more convenient to use.
#
# You can disable cookies yourself in your browser settings.