CSS-LM

pretrain_roberta_including_Preprocess.py 
752 lines · 31.1 KB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RoBERTa masked-LM pre-training runner (adapted from the BERT fine-tuning runner)."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):

        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = [] # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        #t1, t2, is_next_label = self.random_sent(item)
        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            is_next_label = 0

        tokens_a = self.tokenizer.tokenize(t1)
        #tokens_b = self.tokenizer.tokenize(t2)


        # tokenize
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sentence from the corpus. Next-sentence sampling is disabled in this
        no-next-sentence variant, so only the first sentence of the pair is returned.
        :param index: int, index of sample.
        :return: (str, None), sentence and a placeholder is_next label
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one line from the corpus. The second element of the returned pair is kept
        only for interface compatibility and is always an empty string here.
        :param item: int, index of sample.
        :return: (str, str), the sentence and an empty placeholder
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
            #return t1
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "":
                    t1 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2


    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                #pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            #check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line

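
# Illustrative usage sketch (not part of the original CSS-LM pipeline): Dataset_noNext
# expects a plain-text corpus with one sentence per line and a blank line between
# documents. The helper below shows one way to wrap the dataset in a DataLoader; the
# corpus path, checkpoint name and batch size are placeholders, not repository values.
def _example_build_dataloader(corpus_path="corpus.txt", batch_size=8, seq_len=128):
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint
    dataset = Dataset_noNext(corpus_path, tokenizer, seq_len=seq_len, on_memory=True)
    sampler = RandomSampler(dataset)
    # each batch yields (input_ids, input_mask, segment_ids, lm_label_ids, is_next)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
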

class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: list of str. The tokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) list of str. The tokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            is_next: (Optional) int. Next-sentence label (unused in this no-next-sentence variant).
            lm_labels: (Optional) masked-LM labels of the example.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    #print("========")
    #print(tokens)
    #print("---")
    for i, token in enumerate(tokens):
        '''
        print("========")
        print("========")
        print(tokens)
        print("---")
        print(token)
        print("---")
        print(tokenizer.decode(random.randint(0,tokenizer.vocab_size)))
        print(random.randint(0,tokenizer.vocab_size))
        print(type(tokenizer.decode(random.randint(0,tokenizer.vocab_size))))
        print("========")
        print("========")
        exit()
        '''

        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15
            #print(tokenizer.vocab)
            #exit()
            #print(tokenizer.convert_ids_to_tokens(candidate_id))

            '''
            print(tokenizer.convert_ids_to_tokens(candidate_id))
            print("++++")
            print(candidate_id)
            print(tokenizer.encode(tokens))
            a=tokenizer.encode(tokens)
            print(tokenizer.decode(a))
            print(tokenizer.convert_ids_to_tokens(candidate_id))
            print("++++")
            exit()
            '''

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "<mask>"

            # 10% randomly change token to random token
            elif prob < 0.9:
                #tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
                #tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
                # randint is inclusive on both ends, so the upper bound is vocab_size - 1
                candidate_id = random.randint(0, tokenizer.vocab_size - 1)
                w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                if tokens[i] == None:
                    candidate_id = 100
                    w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                tokens[i] = w


            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                #output_label.append(tokenizer.vocab[token])
                w = tokenizer.convert_tokens_to_ids(token)
                if w is not None:
                    output_label.append(w)
                else:
                    print("Token has no id in the tokenizer vocabulary")
                    exit()
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                #output_label.append(tokenizer.vocab["<unk>"])
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    #print(tokens)
    #print("========")
    #exit()

    return tokens, output_label

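
# Illustrative sketch (not part of the original CSS-LM pipeline): exercising
# random_word() on a hypothetical sentence. Of the ~15% of positions selected for
# prediction, 80% become "<mask>", 10% become a random vocabulary token and 10% keep
# the original token; unselected positions get label -1 and are skipped by the loss.
def _example_random_word(sentence="The quick brown fox jumps over the lazy dog."):
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint
    tokens = tokenizer.tokenize(sentence)
    masked_tokens, labels = random_word(list(tokens), tokenizer)  # copy, random_word mutates in place
    for original, masked, label in zip(tokens, masked_tokens, labels):
        # label == -1 means the position is not predicted; otherwise it is the id of
        # the original token that the model has to recover.
        print("{} -> {} (label: {})".format(original, masked, label))
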

def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for <s> and </s> with "- 2"
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
    lm_label_ids = ([-1] + t1_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("<s>")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("</s>")
    segment_ids.append(0)

    '''
    assert len(tokens_b) > 0
    for token in tokens_b:
        tokens.append(token)
        segment_ids.append(1)
    '''
    #tokens.append("[SEP]")
    #segment_ids.append(1)

    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)

    #print(input_ids)
    input_ids = [w if w is not None else 0 for w in input_ids]
    #print(input_ids)
    #exit()

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")

    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next)
    return features

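
# Illustrative sketch (not part of the original CSS-LM pipeline): converting a single
# sentence into padded features, mirroring what Dataset_noNext.__getitem__ does. The
# sentence, checkpoint name and sequence length are placeholders.
def _example_convert_sentence(sentence="A hypothetical training sentence.", max_seq_length=32):
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint
    tokens_a = tokenizer.tokenize(sentence)
    example = InputExample(guid=0, tokens_a=tokens_a, tokens_b=None, is_next=0)
    features = convert_example_to_features(example, max_seq_length, tokenizer)
    # every feature list is padded (or truncated) to exactly max_seq_length entries
    assert len(features.input_ids) == max_seq_length
    assert len(features.input_mask) == max_seq_length
    assert len(features.lm_label_ids) == max_seq_length
    return features
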

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Path to or name of the pre-trained RoBERTa model (e.g. roberta-base).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether to disable CUDA even when it is available.")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="Number of labels of the downstream task (accepted for interface compatibility; unused in this pre-training script).")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=None,
                        type=int,
                        required=True,
                        help="Task id (accepted for interface compatibility; unused in this pre-training script).")
    ####

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)


    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length,
                                    corpus_lines=None, on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = RobertaForMaskedLM.from_pretrained(args.pretrain_model)
    model.to(device)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        # warm up over the first warmup_proportion of the total optimization steps
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=int(num_train_optimization_steps * args.warmup_proportion),
                                                    num_training_steps=num_train_optimization_steps)

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)


        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)


    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        output_loss_file = os.path.join(args.output_dir, "loss")
        loss_fout = open(output_loss_file, 'w')
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                '''
                print("=======")
                print(input_ids)
                for i in range(len(input_ids)):
                    print("---")
                    print(len(input_ids[i]))
                    print(tokenizer.decode(input_ids[i]))
                    print("---")
                exit()
                '''

                # pass the padding mask so that padded positions are not attended to
                output = model(input_ids=input_ids, attention_mask=input_mask, masked_lm_labels=lm_label_ids, return_dict=True)
                loss = output.loss

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    #optimizer.backward(loss)
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                ###
                loss_fout.write("{}\n".format(loss.item()))
                ###

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # clip gradients on the fp32 master parameters kept by AMP
                        #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    ###
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    ###

                    optimizer.step()
                    ###
                    scheduler.step()
                    ###
                    #optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1

            #if global_step % 100000 == 0:
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
            torch.save(model_to_save.state_dict(), output_model_file)

        loss_fout.close()

        # Save a trained model
        logger.info("** ** * Saving fine-tuned model ** ** * ")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        #total_length = len(tokens_a) + len(tokens_b)
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        else:
            tokens_a.pop()


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()

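
# Example invocation (illustrative; the paths and values are placeholders, not values
# shipped with the repository). --data_dir points at the plain-text corpus described
# above, and --num_labels_task / --task are accepted for interface compatibility:
#
#   python pretrain_roberta_including_Preprocess.py \
#       --data_dir /path/to/corpus.txt \
#       --pretrain_model roberta-base \
#       --output_dir /path/to/output_dir \
#       --num_labels_task 2 \
#       --task 0 \
#       --do_train \
#       --on_memory \
#       --max_seq_length 128 \
#       --train_batch_size 32 \
#       --learning_rate 3e-5 \
#       --num_train_epochs 3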