CSS-LM

simple_lm_finetuning.py 
644 lines · 27.5 KB

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


class BERTDataset(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        # remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        # store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        t1, t2, is_next_label = self.random_sent(item)

        # tokenize
        tokens_a = self.tokenizer.tokenize(t1)
        tokens_b = self.tokenizer.tokenize(t2)

        # combine to one sample
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
        """
        t1, t2 = self.get_corpus_line(index)
        if random.random() > 0.5:
            label = 0
        else:
            t2 = self.get_random_line()
            label = 1

        assert len(t1) > 0
        assert len(t2) > 0
        return t1, t2, label

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            t2 = self.all_docs[sample["doc_id"]][sample["line"]+1]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "":
                    t1 = next(self.file).strip()
                    t2 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                t2 = next(self.file).strip()
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t2 == "" or t1 == "":
                    t1 = next(self.file).strip()
                    t2 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = t2

        assert t1 != ""
        assert t2 != ""
        return t1, t2

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                # pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            # check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            # keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line
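
# A minimal sketch (not part of the original file) of how BERTDataset is consumed by
# main() below. The corpus is expected to hold one sentence per line, with blank lines
# separating documents, which is what the loading loop in __init__ relies on; the path
# here is a placeholder.
#
#   dataset = BERTDataset("/path/to/corpus.txt", tokenizer, seq_len=128, on_memory=True)
#   input_ids, input_mask, segment_ids, lm_label_ids, is_next = dataset[0]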


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: list of str. The tokenized text of the first sequence.
            tokens_b: (Optional) list of str. The tokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            is_next: (Optional) int. Next-sentence label: 0 if tokens_b follows tokens_a
                in the corpus, 1 if it is a random sentence from another document.
            lm_labels: (Optional) masked LM labels for the language model task.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
                logger.warning("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label
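
# Illustrative numbers for the scheme above (not from the original file): for a
# 128-token input, roughly 0.15 * 128, i.e. about 19 positions, are selected for
# prediction; of those, about 80% (~15) become "[MASK]", ~10% (~2) are replaced by a
# random vocabulary token, and ~10% (~2) keep their original token. output_label holds
# the original token ids only at the selected positions and -1 everywhere else.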


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as lists of tokens) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as lists of tokens and the is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    assert len(tokens_b) > 0
    for token in tokens_b:
        tokens.append(token)
        segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next)
    return features
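
# Illustrative layout (not from the original file): with max_seq_length = 16 and, after
# truncation, 4 tokens left in tokens_a and 3 in tokens_b, the padded features are
#
#   tokens       = [CLS] a1 a2 a3 a4 [SEP] b1 b2 b3 [SEP]             (10 real positions)
#   segment_ids  = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1,  0, 0, 0, 0, 0, 0]
#   input_mask   = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  0, 0, 0, 0, 0, 0]
#   lm_label_ids = [-1, x, x, x, x, -1, y, y, y, -1,  -1, -1, -1, -1, -1, -1]
#
# where each x / y is either the original token id (if that position was selected for
# masking by random_word) or -1 (if it was not).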


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_corpus",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_corpus)
        train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
                                    corpus_lines=None, on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                                 t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model
        logger.info("***** Saving fine-tuned model *****")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
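
# Worked example (not from the original file): with max_length = 5, lists of length
# 4 and 3 are truncated to 3 and 2; the longer list loses a token first, and an equal
# length is resolved by trimming tokens_b.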


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()
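
# A hypothetical way to reload the weights saved by main() above (this snippet is not
# part of the script; BertForPreTraining and the output file name are taken from it,
# while the model name and path are placeholders):
#
#   model = BertForPreTraining.from_pretrained("bert-base-uncased")
#   model.load_state_dict(torch.load("/path/to/output/pytorch_model.bin", map_location="cpu"))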
